cs7324 Lab 5 - Wide and Deep Networks¶

Chip Henderson - 48996654¶

For this lab I am changing my dataset to the Breast Cancer Gene Expression Profiles (METABRIC) dataset (https://www.kaggle.com/datasets/raghadalharbi/breast-cancer-gene-expression-profiles-metabric).

I wanted a dataset with more feature data for this lab. Additionally, having a history of breast cancer in my family I was curious about the dataset which makes the lab more interesting. I'll be upfront about the fact that my lab is somewhat morbid. I'll be predicting whether a patient is likely to live or die based on the characteristics of their cancer. However, based on a personal family experience, a model like this may have helped drive a more realistic discussion on treatment plan. That aside, thanks to my wife who is an Oncology nurse practicioner and helped me understand some of these terms.

1. Preparation¶

Preprocessing and Class Variable Definition¶

In [ ]:
import pandas as pd

bc_df = pd.read_csv(r'c:\users\chip\source\repos\cs7324_code\Lab 5\METABRIC_RNA_Mutation.csv', sep=',')
bc_df.shape
C:\Users\Chip\AppData\Local\Temp\ipykernel_17188\103049857.py:3: DtypeWarning: Columns (678,688,690,692) have mixed types. Specify dtype option on import or set low_memory=False.
  bc_df = pd.read_csv(r'c:\users\chip\source\repos\cs7324_code\Lab 5\METABRIC_RNA_Mutation.csv', sep=',')
Out[ ]:
(1904, 693)
In [ ]:
# Source: https://stackoverflow.com/questions/34537048/how-to-count-nan-values-in-a-pandas-dataframe
bc_df.info(verbose=True, show_counts=True)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1904 entries, 0 to 1903
Data columns (total 693 columns):
 #    Column                          Non-Null Count  Dtype  
---   ------                          --------------  -----  
 0    patient_id                      1904 non-null   int64  
 1    age_at_diagnosis                1904 non-null   float64
 2    type_of_breast_surgery          1882 non-null   object 
 3    cancer_type                     1904 non-null   object 
 4    cancer_type_detailed            1889 non-null   object 
 5    cellularity                     1850 non-null   object 
 6    chemotherapy                    1904 non-null   int64  
 7    pam50_+_claudin-low_subtype     1904 non-null   object 
 8    cohort                          1904 non-null   float64
 9    er_status_measured_by_ihc       1874 non-null   object 
 10   er_status                       1904 non-null   object 
 11   neoplasm_histologic_grade       1832 non-null   float64
 12   her2_status_measured_by_snp6    1904 non-null   object 
 13   her2_status                     1904 non-null   object 
 14   tumor_other_histologic_subtype  1889 non-null   object 
 15   hormone_therapy                 1904 non-null   int64  
 16   inferred_menopausal_state       1904 non-null   object 
 17   integrative_cluster             1904 non-null   object 
 18   primary_tumor_laterality        1798 non-null   object 
 19   lymph_nodes_examined_positive   1904 non-null   float64
 20   mutation_count                  1859 non-null   float64
 21   nottingham_prognostic_index     1904 non-null   float64
 22   oncotree_code                   1889 non-null   object 
 23   overall_survival_months         1904 non-null   float64
 24   overall_survival                1904 non-null   int64  
 25   pr_status                       1904 non-null   object 
 26   radio_therapy                   1904 non-null   int64  
 27   3-gene_classifier_subtype       1700 non-null   object 
 28   tumor_size                      1884 non-null   float64
 29   tumor_stage                     1403 non-null   float64
 30   death_from_cancer               1903 non-null   object 
 31   brca1                           1904 non-null   float64
 32   brca2                           1904 non-null   float64
 33   palb2                           1904 non-null   float64
 34   pten                            1904 non-null   float64
 35   tp53                            1904 non-null   float64
 36   atm                             1904 non-null   float64
 37   cdh1                            1904 non-null   float64
 38   chek2                           1904 non-null   float64
 39   nbn                             1904 non-null   float64
 40   nf1                             1904 non-null   float64
 41   stk11                           1904 non-null   float64
 42   bard1                           1904 non-null   float64
 43   mlh1                            1904 non-null   float64
 44   msh2                            1904 non-null   float64
 45   msh6                            1904 non-null   float64
 46   pms2                            1904 non-null   float64
 47   epcam                           1904 non-null   float64
 48   rad51c                          1904 non-null   float64
 49   rad51d                          1904 non-null   float64
 50   rad50                           1904 non-null   float64
 51   rb1                             1904 non-null   float64
 52   rbl1                            1904 non-null   float64
 53   rbl2                            1904 non-null   float64
 54   ccna1                           1904 non-null   float64
 55   ccnb1                           1904 non-null   float64
 56   cdk1                            1904 non-null   float64
 57   ccne1                           1904 non-null   float64
 58   cdk2                            1904 non-null   float64
 59   cdc25a                          1904 non-null   float64
 60   ccnd1                           1904 non-null   float64
 61   cdk4                            1904 non-null   float64
 62   cdk6                            1904 non-null   float64
 63   ccnd2                           1904 non-null   float64
 64   cdkn2a                          1904 non-null   float64
 65   cdkn2b                          1904 non-null   float64
 66   myc                             1904 non-null   float64
 67   cdkn1a                          1904 non-null   float64
 68   cdkn1b                          1904 non-null   float64
 69   e2f1                            1904 non-null   float64
 70   e2f2                            1904 non-null   float64
 71   e2f3                            1904 non-null   float64
 72   e2f4                            1904 non-null   float64
 73   e2f5                            1904 non-null   float64
 74   e2f6                            1904 non-null   float64
 75   e2f7                            1904 non-null   float64
 76   e2f8                            1904 non-null   float64
 77   src                             1904 non-null   float64
 78   jak1                            1904 non-null   float64
 79   jak2                            1904 non-null   float64
 80   stat1                           1904 non-null   float64
 81   stat2                           1904 non-null   float64
 82   stat3                           1904 non-null   float64
 83   stat5a                          1904 non-null   float64
 84   stat5b                          1904 non-null   float64
 85   mdm2                            1904 non-null   float64
 86   tp53bp1                         1904 non-null   float64
 87   adam10                          1904 non-null   float64
 88   adam17                          1904 non-null   float64
 89   aph1a                           1904 non-null   float64
 90   aph1b                           1904 non-null   float64
 91   arrdc1                          1904 non-null   float64
 92   cir1                            1904 non-null   float64
 93   ctbp1                           1904 non-null   float64
 94   ctbp2                           1904 non-null   float64
 95   cul1                            1904 non-null   float64
 96   dll1                            1904 non-null   float64
 97   dll3                            1904 non-null   float64
 98   dll4                            1904 non-null   float64
 99   dtx1                            1904 non-null   float64
 100  dtx2                            1904 non-null   float64
 101  dtx3                            1904 non-null   float64
 102  dtx4                            1904 non-null   float64
 103  ep300                           1904 non-null   float64
 104  fbxw7                           1904 non-null   float64
 105  hdac1                           1904 non-null   float64
 106  hdac2                           1904 non-null   float64
 107  hes1                            1904 non-null   float64
 108  hes5                            1904 non-null   float64
 109  heyl                            1904 non-null   float64
 110  itch                            1904 non-null   float64
 111  jag1                            1904 non-null   float64
 112  jag2                            1904 non-null   float64
 113  kdm5a                           1904 non-null   float64
 114  lfng                            1904 non-null   float64
 115  maml1                           1904 non-null   float64
 116  maml2                           1904 non-null   float64
 117  maml3                           1904 non-null   float64
 118  ncor2                           1904 non-null   float64
 119  ncstn                           1904 non-null   float64
 120  notch1                          1904 non-null   float64
 121  notch2                          1904 non-null   float64
 122  notch3                          1904 non-null   float64
 123  nrarp                           1904 non-null   float64
 124  numb                            1904 non-null   float64
 125  numbl                           1904 non-null   float64
 126  psen1                           1904 non-null   float64
 127  psen2                           1904 non-null   float64
 128  psenen                          1904 non-null   float64
 129  rbpj                            1904 non-null   float64
 130  rbpjl                           1904 non-null   float64
 131  rfng                            1904 non-null   float64
 132  snw1                            1904 non-null   float64
 133  spen                            1904 non-null   float64
 134  hes2                            1904 non-null   float64
 135  hes4                            1904 non-null   float64
 136  hes7                            1904 non-null   float64
 137  hey1                            1904 non-null   float64
 138  hey2                            1904 non-null   float64
 139  acvr1                           1904 non-null   float64
 140  acvr1b                          1904 non-null   float64
 141  acvr1c                          1904 non-null   float64
 142  acvr2a                          1904 non-null   float64
 143  acvr2b                          1904 non-null   float64
 144  acvrl1                          1904 non-null   float64
 145  akt1                            1904 non-null   float64
 146  akt1s1                          1904 non-null   float64
 147  akt2                            1904 non-null   float64
 148  apaf1                           1904 non-null   float64
 149  arl11                           1904 non-null   float64
 150  atr                             1904 non-null   float64
 151  aurka                           1904 non-null   float64
 152  bad                             1904 non-null   float64
 153  bcl2                            1904 non-null   float64
 154  bcl2l1                          1904 non-null   float64
 155  bmp10                           1904 non-null   float64
 156  bmp15                           1904 non-null   float64
 157  bmp2                            1904 non-null   float64
 158  bmp3                            1904 non-null   float64
 159  bmp4                            1904 non-null   float64
 160  bmp5                            1904 non-null   float64
 161  bmp6                            1904 non-null   float64
 162  bmp7                            1904 non-null   float64
 163  bmpr1a                          1904 non-null   float64
 164  bmpr1b                          1904 non-null   float64
 165  bmpr2                           1904 non-null   float64
 166  braf                            1904 non-null   float64
 167  casp10                          1904 non-null   float64
 168  casp3                           1904 non-null   float64
 169  casp6                           1904 non-null   float64
 170  casp7                           1904 non-null   float64
 171  casp8                           1904 non-null   float64
 172  casp9                           1904 non-null   float64
 173  chek1                           1904 non-null   float64
 174  csf1                            1904 non-null   float64
 175  csf1r                           1904 non-null   float64
 176  cxcl8                           1904 non-null   float64
 177  cxcr1                           1904 non-null   float64
 178  cxcr2                           1904 non-null   float64
 179  dab2                            1904 non-null   float64
 180  diras3                          1904 non-null   float64
 181  dlec1                           1904 non-null   float64
 182  dph1                            1904 non-null   float64
 183  egfr                            1904 non-null   float64
 184  eif4e                           1904 non-null   float64
 185  eif4ebp1                        1904 non-null   float64
 186  eif5a2                          1904 non-null   float64
 187  erbb2                           1904 non-null   float64
 188  erbb3                           1904 non-null   float64
 189  erbb4                           1904 non-null   float64
 190  fas                             1904 non-null   float64
 191  fgf1                            1904 non-null   float64
 192  fgfr1                           1904 non-null   float64
 193  folr1                           1904 non-null   float64
 194  folr2                           1904 non-null   float64
 195  folr3                           1904 non-null   float64
 196  foxo1                           1904 non-null   float64
 197  foxo3                           1904 non-null   float64
 198  gdf11                           1904 non-null   float64
 199  gdf2                            1904 non-null   float64
 200  gsk3b                           1904 non-null   float64
 201  hif1a                           1904 non-null   float64
 202  hla-g                           1904 non-null   float64
 203  hras                            1904 non-null   float64
 204  igf1                            1904 non-null   float64
 205  igf1r                           1904 non-null   float64
 206  inha                            1904 non-null   float64
 207  inhba                           1904 non-null   float64
 208  inhbc                           1904 non-null   float64
 209  itgav                           1904 non-null   float64
 210  itgb3                           1904 non-null   float64
 211  izumo1r                         1904 non-null   float64
 212  kdr                             1904 non-null   float64
 213  kit                             1904 non-null   float64
 214  kras                            1904 non-null   float64
 215  map2k1                          1904 non-null   float64
 216  map2k2                          1904 non-null   float64
 217  map2k3                          1904 non-null   float64
 218  map2k4                          1904 non-null   float64
 219  map2k5                          1904 non-null   float64
 220  map3k1                          1904 non-null   float64
 221  map3k3                          1904 non-null   float64
 222  map3k4                          1904 non-null   float64
 223  map3k5                          1904 non-null   float64
 224  mapk1                           1904 non-null   float64
 225  mapk12                          1904 non-null   float64
 226  mapk14                          1904 non-null   float64
 227  mapk3                           1904 non-null   float64
 228  mapk4                           1904 non-null   float64
 229  mapk6                           1904 non-null   float64
 230  mapk7                           1904 non-null   float64
 231  mapk8                           1904 non-null   float64
 232  mapk9                           1904 non-null   float64
 233  mdc1                            1904 non-null   float64
 234  mlst8                           1904 non-null   float64
 235  mmp1                            1904 non-null   float64
 236  mmp10                           1904 non-null   float64
 237  mmp11                           1904 non-null   float64
 238  mmp12                           1904 non-null   float64
 239  mmp13                           1904 non-null   float64
 240  mmp14                           1904 non-null   float64
 241  mmp15                           1904 non-null   float64
 242  mmp16                           1904 non-null   float64
 243  mmp17                           1904 non-null   float64
 244  mmp19                           1904 non-null   float64
 245  mmp2                            1904 non-null   float64
 246  mmp21                           1904 non-null   float64
 247  mmp23b                          1904 non-null   float64
 248  mmp24                           1904 non-null   float64
 249  mmp25                           1904 non-null   float64
 250  mmp26                           1904 non-null   float64
 251  mmp27                           1904 non-null   float64
 252  mmp28                           1904 non-null   float64
 253  mmp3                            1904 non-null   float64
 254  mmp7                            1904 non-null   float64
 255  mmp9                            1904 non-null   float64
 256  mtor                            1904 non-null   float64
 257  nfkb1                           1904 non-null   float64
 258  nfkb2                           1904 non-null   float64
 259  opcml                           1904 non-null   float64
 260  pdgfa                           1904 non-null   float64
 261  pdgfb                           1904 non-null   float64
 262  pdgfra                          1904 non-null   float64
 263  pdgfrb                          1904 non-null   float64
 264  pdpk1                           1904 non-null   float64
 265  peg3                            1904 non-null   float64
 266  pik3ca                          1904 non-null   float64
 267  pik3r1                          1904 non-null   float64
 268  pik3r2                          1904 non-null   float64
 269  plagl1                          1904 non-null   float64
 270  ptk2                            1904 non-null   float64
 271  rab25                           1904 non-null   float64
 272  rad51                           1904 non-null   float64
 273  raf1                            1904 non-null   float64
 274  rassf1                          1904 non-null   float64
 275  rheb                            1904 non-null   float64
 276  rictor                          1904 non-null   float64
 277  rps6                            1904 non-null   float64
 278  rps6ka1                         1904 non-null   float64
 279  rps6ka2                         1904 non-null   float64
 280  rps6kb1                         1904 non-null   float64
 281  rps6kb2                         1904 non-null   float64
 282  rptor                           1904 non-null   float64
 283  slc19a1                         1904 non-null   float64
 284  smad1                           1904 non-null   float64
 285  smad2                           1904 non-null   float64
 286  smad3                           1904 non-null   float64
 287  smad4                           1904 non-null   float64
 288  smad5                           1904 non-null   float64
 289  smad6                           1904 non-null   float64
 290  smad7                           1904 non-null   float64
 291  smad9                           1904 non-null   float64
 292  sptbn1                          1904 non-null   float64
 293  terc                            1904 non-null   float64
 294  tert                            1904 non-null   float64
 295  tgfb1                           1904 non-null   float64
 296  tgfb2                           1904 non-null   float64
 297  tgfb3                           1904 non-null   float64
 298  tgfbr1                          1904 non-null   float64
 299  tgfbr2                          1904 non-null   float64
 300  tgfbr3                          1904 non-null   float64
 301  tsc1                            1904 non-null   float64
 302  tsc2                            1904 non-null   float64
 303  vegfa                           1904 non-null   float64
 304  vegfb                           1904 non-null   float64
 305  wfdc2                           1904 non-null   float64
 306  wwox                            1904 non-null   float64
 307  zfyve9                          1904 non-null   float64
 308  arid1a                          1904 non-null   float64
 309  arid1b                          1904 non-null   float64
 310  cbfb                            1904 non-null   float64
 311  gata3                           1904 non-null   float64
 312  kmt2c                           1904 non-null   float64
 313  kmt2d                           1904 non-null   float64
 314  myh9                            1904 non-null   float64
 315  ncor1                           1904 non-null   float64
 316  pde4dip                         1904 non-null   float64
 317  ptprd                           1904 non-null   float64
 318  ros1                            1904 non-null   float64
 319  runx1                           1904 non-null   float64
 320  tbx3                            1904 non-null   float64
 321  abcb1                           1904 non-null   float64
 322  abcb11                          1904 non-null   float64
 323  abcc1                           1904 non-null   float64
 324  abcc10                          1904 non-null   float64
 325  bbc3                            1904 non-null   float64
 326  bmf                             1904 non-null   float64
 327  cyp2c8                          1904 non-null   float64
 328  cyp3a4                          1904 non-null   float64
 329  fgf2                            1904 non-null   float64
 330  fn1                             1904 non-null   float64
 331  map2                            1904 non-null   float64
 332  map4                            1904 non-null   float64
 333  mapt                            1904 non-null   float64
 334  nr1i2                           1904 non-null   float64
 335  slco1b3                         1904 non-null   float64
 336  tubb1                           1904 non-null   float64
 337  tubb4a                          1904 non-null   float64
 338  tubb4b                          1904 non-null   float64
 339  twist1                          1904 non-null   float64
 340  adgra2                          1904 non-null   float64
 341  afdn                            1904 non-null   float64
 342  aff2                            1904 non-null   float64
 343  agmo                            1904 non-null   float64
 344  agtr2                           1904 non-null   float64
 345  ahnak                           1904 non-null   float64
 346  ahnak2                          1904 non-null   float64
 347  akap9                           1904 non-null   float64
 348  alk                             1904 non-null   float64
 349  apc                             1904 non-null   float64
 350  arid2                           1904 non-null   float64
 351  arid5b                          1904 non-null   float64
 352  asxl1                           1904 non-null   float64
 353  asxl2                           1904 non-null   float64
 354  bap1                            1904 non-null   float64
 355  bcas3                           1904 non-null   float64
 356  birc6                           1904 non-null   float64
 357  cacna2d3                        1904 non-null   float64
 358  ccnd3                           1904 non-null   float64
 359  chd1                            1904 non-null   float64
 360  clk3                            1904 non-null   float64
 361  clrn2                           1904 non-null   float64
 362  col12a1                         1904 non-null   float64
 363  col22a1                         1904 non-null   float64
 364  col6a3                          1904 non-null   float64
 365  ctcf                            1904 non-null   float64
 366  ctnna1                          1904 non-null   float64
 367  ctnna3                          1904 non-null   float64
 368  dnah11                          1904 non-null   float64
 369  dnah2                           1904 non-null   float64
 370  dnah5                           1904 non-null   float64
 371  dtwd2                           1904 non-null   float64
 372  fam20c                          1904 non-null   float64
 373  fanca                           1904 non-null   float64
 374  fancd2                          1904 non-null   float64
 375  flt3                            1904 non-null   float64
 376  foxp1                           1904 non-null   float64
 377  frmd3                           1904 non-null   float64
 378  gh1                             1904 non-null   float64
 379  gldc                            1904 non-null   float64
 380  gpr32                           1904 non-null   float64
 381  gps2                            1904 non-null   float64
 382  hdac9                           1904 non-null   float64
 383  herc2                           1904 non-null   float64
 384  hist1h2bc                       1904 non-null   float64
 385  kdm3a                           1904 non-null   float64
 386  kdm6a                           1904 non-null   float64
 387  klrg1                           1904 non-null   float64
 388  l1cam                           1904 non-null   float64
 389  lama2                           1904 non-null   float64
 390  lamb3                           1904 non-null   float64
 391  large1                          1904 non-null   float64
 392  ldlrap1                         1904 non-null   float64
 393  lifr                            1904 non-null   float64
 394  lipi                            1904 non-null   float64
 395  magea8                          1904 non-null   float64
 396  map3k10                         1904 non-null   float64
 397  map3k13                         1904 non-null   float64
 398  men1                            1904 non-null   float64
 399  mtap                            1904 non-null   float64
 400  muc16                           1904 non-null   float64
 401  myo1a                           1904 non-null   float64
 402  myo3a                           1904 non-null   float64
 403  ncoa3                           1904 non-null   float64
 404  nek1                            1904 non-null   float64
 405  nf2                             1904 non-null   float64
 406  npnt                            1904 non-null   float64
 407  nr2f1                           1904 non-null   float64
 408  nr3c1                           1904 non-null   float64
 409  nras                            1904 non-null   float64
 410  nrg3                            1904 non-null   float64
 411  nt5e                            1904 non-null   float64
 412  or6a2                           1904 non-null   float64
 413  palld                           1904 non-null   float64
 414  pbrm1                           1904 non-null   float64
 415  ppp2cb                          1904 non-null   float64
 416  ppp2r2a                         1904 non-null   float64
 417  prkacg                          1904 non-null   float64
 418  prkce                           1904 non-null   float64
 419  prkcq                           1904 non-null   float64
 420  prkcz                           1904 non-null   float64
 421  prkg1                           1904 non-null   float64
 422  prps2                           1904 non-null   float64
 423  prr16                           1904 non-null   float64
 424  ptpn22                          1904 non-null   float64
 425  ptprm                           1904 non-null   float64
 426  rasgef1b                        1904 non-null   float64
 427  rpgr                            1904 non-null   float64
 428  ryr2                            1904 non-null   float64
 429  sbno1                           1904 non-null   float64
 430  setd1a                          1904 non-null   float64
 431  setd2                           1904 non-null   float64
 432  setdb1                          1904 non-null   float64
 433  sf3b1                           1904 non-null   float64
 434  sgcd                            1904 non-null   float64
 435  shank2                          1904 non-null   float64
 436  siah1                           1904 non-null   float64
 437  sik1                            1904 non-null   float64
 438  sik2                            1904 non-null   float64
 439  smarcb1                         1904 non-null   float64
 440  smarcc1                         1904 non-null   float64
 441  smarcc2                         1904 non-null   float64
 442  smarcd1                         1904 non-null   float64
 443  spaca1                          1904 non-null   float64
 444  stab2                           1904 non-null   float64
 445  stmn2                           1904 non-null   float64
 446  syne1                           1904 non-null   float64
 447  taf1                            1904 non-null   float64
 448  taf4b                           1904 non-null   float64
 449  tbl1xr1                         1904 non-null   float64
 450  tg                              1904 non-null   float64
 451  thada                           1904 non-null   float64
 452  thsd7a                          1904 non-null   float64
 453  ttyh1                           1904 non-null   float64
 454  ubr5                            1904 non-null   float64
 455  ush2a                           1904 non-null   float64
 456  usp9x                           1904 non-null   float64
 457  utrn                            1904 non-null   float64
 458  zfp36l1                         1904 non-null   float64
 459  ackr3                           1904 non-null   float64
 460  akr1c1                          1904 non-null   float64
 461  akr1c2                          1904 non-null   float64
 462  akr1c3                          1904 non-null   float64
 463  akr1c4                          1904 non-null   float64
 464  akt3                            1904 non-null   float64
 465  ar                              1904 non-null   float64
 466  bche                            1904 non-null   float64
 467  cdk8                            1904 non-null   float64
 468  cdkn2c                          1904 non-null   float64
 469  cyb5a                           1904 non-null   float64
 470  cyp11a1                         1904 non-null   float64
 471  cyp11b2                         1904 non-null   float64
 472  cyp17a1                         1904 non-null   float64
 473  cyp19a1                         1904 non-null   float64
 474  cyp21a2                         1904 non-null   float64
 475  cyp3a43                         1904 non-null   float64
 476  cyp3a5                          1904 non-null   float64
 477  cyp3a7                          1904 non-null   float64
 478  ddc                             1904 non-null   float64
 479  hes6                            1904 non-null   float64
 480  hsd17b1                         1904 non-null   float64
 481  hsd17b10                        1904 non-null   float64
 482  hsd17b11                        1904 non-null   float64
 483  hsd17b12                        1904 non-null   float64
 484  hsd17b13                        1904 non-null   float64
 485  hsd17b14                        1904 non-null   float64
 486  hsd17b2                         1904 non-null   float64
 487  hsd17b3                         1904 non-null   float64
 488  hsd17b4                         1904 non-null   float64
 489  hsd17b6                         1904 non-null   float64
 490  hsd17b7                         1904 non-null   float64
 491  hsd17b8                         1904 non-null   float64
 492  hsd3b1                          1904 non-null   float64
 493  hsd3b2                          1904 non-null   float64
 494  hsd3b7                          1904 non-null   float64
 495  mecom                           1904 non-null   float64
 496  met                             1904 non-null   float64
 497  ncoa2                           1904 non-null   float64
 498  nrip1                           1904 non-null   float64
 499  pik3r3                          1904 non-null   float64
 500  prkci                           1904 non-null   float64
 501  prkd1                           1904 non-null   float64
 502  ran                             1904 non-null   float64
 503  rdh5                            1904 non-null   float64
 504  sdc4                            1904 non-null   float64
 505  serpini1                        1904 non-null   float64
 506  shbg                            1904 non-null   float64
 507  slc29a1                         1904 non-null   float64
 508  sox9                            1904 non-null   float64
 509  spry2                           1904 non-null   float64
 510  srd5a1                          1904 non-null   float64
 511  srd5a2                          1904 non-null   float64
 512  srd5a3                          1904 non-null   float64
 513  st7                             1904 non-null   float64
 514  star                            1904 non-null   float64
 515  tnk2                            1904 non-null   float64
 516  tulp4                           1904 non-null   float64
 517  ugt2b15                         1904 non-null   float64
 518  ugt2b17                         1904 non-null   float64
 519  ugt2b7                          1904 non-null   float64
 520  pik3ca_mut                      1904 non-null   object 
 521  tp53_mut                        1904 non-null   object 
 522  muc16_mut                       1904 non-null   object 
 523  ahnak2_mut                      1904 non-null   object 
 524  kmt2c_mut                       1904 non-null   object 
 525  syne1_mut                       1904 non-null   object 
 526  gata3_mut                       1904 non-null   object 
 527  map3k1_mut                      1904 non-null   object 
 528  ahnak_mut                       1904 non-null   object 
 529  dnah11_mut                      1904 non-null   object 
 530  cdh1_mut                        1904 non-null   object 
 531  dnah2_mut                       1904 non-null   object 
 532  kmt2d_mut                       1904 non-null   object 
 533  ush2a_mut                       1904 non-null   object 
 534  ryr2_mut                        1904 non-null   object 
 535  dnah5_mut                       1904 non-null   object 
 536  herc2_mut                       1904 non-null   object 
 537  pde4dip_mut                     1904 non-null   object 
 538  akap9_mut                       1904 non-null   object 
 539  tg_mut                          1904 non-null   object 
 540  birc6_mut                       1904 non-null   object 
 541  utrn_mut                        1904 non-null   object 
 542  tbx3_mut                        1904 non-null   object 
 543  col6a3_mut                      1904 non-null   object 
 544  arid1a_mut                      1904 non-null   object 
 545  lama2_mut                       1904 non-null   object 
 546  notch1_mut                      1904 non-null   object 
 547  cbfb_mut                        1904 non-null   object 
 548  ncor2_mut                       1904 non-null   object 
 549  col12a1_mut                     1904 non-null   object 
 550  col22a1_mut                     1904 non-null   object 
 551  pten_mut                        1904 non-null   object 
 552  akt1_mut                        1904 non-null   object 
 553  atr_mut                         1904 non-null   object 
 554  thada_mut                       1904 non-null   object 
 555  ncor1_mut                       1904 non-null   object 
 556  stab2_mut                       1904 non-null   object 
 557  myh9_mut                        1904 non-null   object 
 558  runx1_mut                       1904 non-null   object 
 559  nf1_mut                         1904 non-null   object 
 560  map2k4_mut                      1904 non-null   object 
 561  ros1_mut                        1904 non-null   object 
 562  lamb3_mut                       1904 non-null   object 
 563  arid1b_mut                      1904 non-null   object 
 564  erbb2_mut                       1904 non-null   object 
 565  sf3b1_mut                       1904 non-null   object 
 566  shank2_mut                      1904 non-null   object 
 567  ep300_mut                       1904 non-null   object 
 568  ptprd_mut                       1904 non-null   object 
 569  usp9x_mut                       1904 non-null   object 
 570  setd2_mut                       1904 non-null   object 
 571  setd1a_mut                      1904 non-null   object 
 572  thsd7a_mut                      1904 non-null   object 
 573  afdn_mut                        1904 non-null   object 
 574  erbb3_mut                       1904 non-null   object 
 575  rb1_mut                         1904 non-null   object 
 576  myo1a_mut                       1904 non-null   object 
 577  alk_mut                         1904 non-null   object 
 578  fanca_mut                       1904 non-null   object 
 579  adgra2_mut                      1904 non-null   object 
 580  ubr5_mut                        1904 non-null   object 
 581  pik3r1_mut                      1904 non-null   object 
 582  myo3a_mut                       1904 non-null   object 
 583  asxl2_mut                       1904 non-null   object 
 584  apc_mut                         1904 non-null   object 
 585  ctcf_mut                        1904 non-null   object 
 586  asxl1_mut                       1904 non-null   object 
 587  fancd2_mut                      1904 non-null   object 
 588  taf1_mut                        1904 non-null   object 
 589  kdm6a_mut                       1904 non-null   object 
 590  ctnna3_mut                      1904 non-null   object 
 591  brca1_mut                       1904 non-null   object 
 592  ptprm_mut                       1904 non-null   object 
 593  foxo3_mut                       1904 non-null   object 
 594  usp28_mut                       1904 non-null   object 
 595  gldc_mut                        1904 non-null   object 
 596  brca2_mut                       1904 non-null   object 
 597  cacna2d3_mut                    1904 non-null   object 
 598  arid2_mut                       1904 non-null   object 
 599  aff2_mut                        1904 non-null   object 
 600  lifr_mut                        1904 non-null   object 
 601  sbno1_mut                       1904 non-null   object 
 602  kdm3a_mut                       1904 non-null   object 
 603  ncoa3_mut                       1904 non-null   object 
 604  bap1_mut                        1904 non-null   object 
 605  l1cam_mut                       1904 non-null   object 
 606  pbrm1_mut                       1904 non-null   object 
 607  chd1_mut                        1904 non-null   object 
 608  jak1_mut                        1904 non-null   object 
 609  setdb1_mut                      1904 non-null   object 
 610  fam20c_mut                      1904 non-null   object 
 611  arid5b_mut                      1904 non-null   object 
 612  egfr_mut                        1904 non-null   object 
 613  map3k10_mut                     1904 non-null   object 
 614  smarcc2_mut                     1904 non-null   object 
 615  erbb4_mut                       1904 non-null   object 
 616  npnt_mut                        1904 non-null   object 
 617  nek1_mut                        1904 non-null   object 
 618  agmo_mut                        1904 non-null   object 
 619  zfp36l1_mut                     1904 non-null   object 
 620  smad4_mut                       1904 non-null   object 
 621  sik1_mut                        1904 non-null   object 
 622  casp8_mut                       1904 non-null   object 
 623  prkcq_mut                       1904 non-null   object 
 624  smarcc1_mut                     1904 non-null   object 
 625  palld_mut                       1904 non-null   object 
 626  dcaf4l2_mut                     1904 non-null   object 
 627  bcas3_mut                       1904 non-null   object 
 628  cdkn1b_mut                      1904 non-null   object 
 629  gps2_mut                        1904 non-null   object 
 630  men1_mut                        1904 non-null   object 
 631  stk11_mut                       1904 non-null   object 
 632  sik2_mut                        1904 non-null   object 
 633  ptpn22_mut                      1904 non-null   object 
 634  brip1_mut                       1904 non-null   object 
 635  flt3_mut                        1904 non-null   object 
 636  nrg3_mut                        1904 non-null   object 
 637  fbxw7_mut                       1904 non-null   object 
 638  ttyh1_mut                       1904 non-null   object 
 639  taf4b_mut                       1904 non-null   object 
 640  or6a2_mut                       1904 non-null   object 
 641  map3k13_mut                     1904 non-null   object 
 642  hdac9_mut                       1904 non-null   object 
 643  prkacg_mut                      1904 non-null   object 
 644  rpgr_mut                        1904 non-null   object 
 645  large1_mut                      1904 non-null   object 
 646  foxp1_mut                       1904 non-null   object 
 647  clk3_mut                        1904 non-null   object 
 648  prkcz_mut                       1904 non-null   object 
 649  lipi_mut                        1904 non-null   object 
 650  ppp2r2a_mut                     1904 non-null   object 
 651  prkce_mut                       1904 non-null   object 
 652  gh1_mut                         1904 non-null   object 
 653  gpr32_mut                       1904 non-null   object 
 654  kras_mut                        1904 non-null   object 
 655  nf2_mut                         1904 non-null   object 
 656  chek2_mut                       1904 non-null   object 
 657  ldlrap1_mut                     1904 non-null   object 
 658  clrn2_mut                       1904 non-null   object 
 659  acvrl1_mut                      1904 non-null   object 
 660  agtr2_mut                       1904 non-null   object 
 661  cdkn2a_mut                      1904 non-null   object 
 662  ctnna1_mut                      1904 non-null   object 
 663  magea8_mut                      1904 non-null   object 
 664  prr16_mut                       1904 non-null   object 
 665  dtwd2_mut                       1904 non-null   object 
 666  akt2_mut                        1904 non-null   object 
 667  braf_mut                        1904 non-null   object 
 668  foxo1_mut                       1904 non-null   object 
 669  nt5e_mut                        1904 non-null   object 
 670  ccnd3_mut                       1904 non-null   object 
 671  nr3c1_mut                       1904 non-null   object 
 672  prkg1_mut                       1904 non-null   object 
 673  tbl1xr1_mut                     1904 non-null   object 
 674  frmd3_mut                       1904 non-null   object 
 675  smad2_mut                       1904 non-null   object 
 676  sgcd_mut                        1904 non-null   object 
 677  spaca1_mut                      1904 non-null   object 
 678  rasgef1b_mut                    1904 non-null   object 
 679  hist1h2bc_mut                   1904 non-null   object 
 680  nr2f1_mut                       1904 non-null   object 
 681  klrg1_mut                       1904 non-null   object 
 682  mbl2_mut                        1904 non-null   object 
 683  mtap_mut                        1904 non-null   object 
 684  ppp2cb_mut                      1904 non-null   object 
 685  smarcd1_mut                     1904 non-null   object 
 686  nras_mut                        1904 non-null   object 
 687  ndfip1_mut                      1904 non-null   object 
 688  hras_mut                        1904 non-null   object 
 689  prps2_mut                       1904 non-null   object 
 690  smarcb1_mut                     1904 non-null   object 
 691  stmn2_mut                       1904 non-null   object 
 692  siah1_mut                       1904 non-null   object 
dtypes: float64(498), int64(5), object(190)
memory usage: 10.1+ MB
In [ ]:
bc_df.head()
Out[ ]:
patient_id age_at_diagnosis type_of_breast_surgery cancer_type cancer_type_detailed cellularity chemotherapy pam50_+_claudin-low_subtype cohort er_status_measured_by_ihc ... mtap_mut ppp2cb_mut smarcd1_mut nras_mut ndfip1_mut hras_mut prps2_mut smarcb1_mut stmn2_mut siah1_mut
0 0 75.65 MASTECTOMY Breast Cancer Breast Invasive Ductal Carcinoma NaN 0 claudin-low 1.0 Positve ... 0 0 0 0 0 0 0 0 0 0
1 2 43.19 BREAST CONSERVING Breast Cancer Breast Invasive Ductal Carcinoma High 0 LumA 1.0 Positve ... 0 0 0 0 0 0 0 0 0 0
2 5 48.87 MASTECTOMY Breast Cancer Breast Invasive Ductal Carcinoma High 1 LumB 1.0 Positve ... 0 0 0 0 0 0 0 0 0 0
3 6 47.68 MASTECTOMY Breast Cancer Breast Mixed Ductal and Lobular Carcinoma Moderate 1 LumB 1.0 Positve ... 0 0 0 0 0 0 0 0 0 0
4 8 76.97 MASTECTOMY Breast Cancer Breast Mixed Ductal and Lobular Carcinoma High 1 LumB 1.0 Positve ... 0 0 0 0 0 0 0 0 0 0

5 rows × 693 columns

In [ ]:
bc_df.describe()
Out[ ]:
patient_id age_at_diagnosis chemotherapy cohort neoplasm_histologic_grade hormone_therapy lymph_nodes_examined_positive mutation_count nottingham_prognostic_index overall_survival_months ... srd5a1 srd5a2 srd5a3 st7 star tnk2 tulp4 ugt2b15 ugt2b17 ugt2b7
count 1904.000000 1904.000000 1904.000000 1904.000000 1832.000000 1904.000000 1904.000000 1859.000000 1904.000000 1904.000000 ... 1.904000e+03 1.904000e+03 1.904000e+03 1.904000e+03 1904.000000 1.904000e+03 1.904000e+03 1.904000e+03 1904.000000 1.904000e+03
mean 3921.982143 61.087054 0.207983 2.643908 2.415939 0.616597 2.002101 5.697687 4.033019 125.121324 ... 4.726891e-07 -3.676471e-07 -9.453782e-07 -1.050420e-07 -0.000002 3.676471e-07 4.726891e-07 7.878151e-07 0.000000 3.731842e-18
std 2358.478332 12.978711 0.405971 1.228615 0.650612 0.486343 4.079993 4.058778 1.144492 76.334148 ... 1.000263e+00 1.000262e+00 1.000262e+00 1.000263e+00 1.000262 1.000264e+00 1.000262e+00 1.000263e+00 1.000262 1.000262e+00
min 0.000000 21.930000 0.000000 1.000000 1.000000 0.000000 0.000000 1.000000 1.000000 0.000000 ... -2.120800e+00 -3.364800e+00 -2.719400e+00 -4.982700e+00 -2.981700 -3.833300e+00 -3.609300e+00 -1.166900e+00 -2.112600 -1.051600e+00
25% 896.500000 51.375000 0.000000 1.000000 2.000000 0.000000 0.000000 3.000000 3.046000 60.825000 ... -6.188500e-01 -6.104750e-01 -6.741750e-01 -6.136750e-01 -0.632900 -6.664750e-01 -7.102000e-01 -5.058250e-01 -0.476200 -7.260000e-01
50% 4730.500000 61.770000 0.000000 3.000000 3.000000 1.000000 0.000000 5.000000 4.042000 115.616667 ... -2.456500e-01 -4.690000e-02 -1.422500e-01 -5.175000e-02 -0.026650 7.000000e-04 -2.980000e-02 -2.885500e-01 -0.133400 -4.248000e-01
75% 5536.250000 70.592500 0.000000 3.000000 3.000000 1.000000 2.000000 7.000000 5.040250 184.716667 ... 3.306000e-01 5.144500e-01 5.146000e-01 5.787750e-01 0.590350 6.429000e-01 5.957250e-01 6.022500e-02 0.270375 4.284000e-01
max 7299.000000 96.290000 1.000000 5.000000 3.000000 1.000000 45.000000 80.000000 6.360000 355.200000 ... 6.534900e+00 1.027030e+01 6.329000e+00 4.571300e+00 12.742300 3.938800e+00 3.833400e+00 1.088490e+01 12.643900 3.284400e+00

8 rows × 503 columns

I'll go ahead and drop na values to help reduce the dataset to more usable data. I'm also dropping the following:

  • 'cancer type' as this entire dataset is breast cancer related and all the values are the same
  • 'cohort' because this is an assigned value and not a measured one that could help prediction
  • 'overall_survival' because this is going to be represented by my classification groups and also would have provided misleading results if the person didn't die of disease

I'm also going to drop the genetic attributes (features 31 through 693) because of the following:

  • They increase the size significantly
  • The values aren't easily understandable
  • They don't necessarily provide meaningful contribution to the classification objective to non-medical professionals
In [ ]:
import copy
# bc_df_full = bc_df.copy() # for use later
bc_df = bc_df.dropna(axis=0)
# bc_df = bc_df.dropna(axis=1)

bc_df = bc_df.drop(bc_df.columns[31:693],axis=1)

features_to_drop = ['cancer_type', 'overall_survival', 'cohort']

bc_df = bc_df.drop(features_to_drop, axis=1)
bc_df.reset_index() # new

# rename column to remove + symbol to avoid any potential datatype issues
bc_df.rename(columns={'pam50_+_claudin-low_subtype':'pam50_plus_claudin-low_subtype'},inplace=True)

bc_df.shape
Out[ ]:
(1092, 28)

That removed roughly 900 instances of data and a few hundered columns which makes my data far easier to work with. To verify I have no more missing data I'll import missingno to visualize it.

In [ ]:
# Referencing code from lecture and in-class examples
import matplotlib
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter('ignore', DeprecationWarning)
%matplotlib inline 

import missingno as mn
# As a departure from lecture code I'm using a bar chart, 
# Matrix version gave me errors regarding 'grid_b' which I wasn't able to resolve
mn.bar(bc_df) 
Out[ ]:
<Axes: >
No description has been provided for this image

No more missing data from my features so we're good to proceed.

In [ ]:
# Determine the remaining datatpyes I'm working with
bc_df.info()
<class 'pandas.core.frame.DataFrame'>
Index: 1092 entries, 1 to 1664
Data columns (total 28 columns):
 #   Column                          Non-Null Count  Dtype  
---  ------                          --------------  -----  
 0   patient_id                      1092 non-null   int64  
 1   age_at_diagnosis                1092 non-null   float64
 2   type_of_breast_surgery          1092 non-null   object 
 3   cancer_type_detailed            1092 non-null   object 
 4   cellularity                     1092 non-null   object 
 5   chemotherapy                    1092 non-null   int64  
 6   pam50_plus_claudin-low_subtype  1092 non-null   object 
 7   er_status_measured_by_ihc       1092 non-null   object 
 8   er_status                       1092 non-null   object 
 9   neoplasm_histologic_grade       1092 non-null   float64
 10  her2_status_measured_by_snp6    1092 non-null   object 
 11  her2_status                     1092 non-null   object 
 12  tumor_other_histologic_subtype  1092 non-null   object 
 13  hormone_therapy                 1092 non-null   int64  
 14  inferred_menopausal_state       1092 non-null   object 
 15  integrative_cluster             1092 non-null   object 
 16  primary_tumor_laterality        1092 non-null   object 
 17  lymph_nodes_examined_positive   1092 non-null   float64
 18  mutation_count                  1092 non-null   float64
 19  nottingham_prognostic_index     1092 non-null   float64
 20  oncotree_code                   1092 non-null   object 
 21  overall_survival_months         1092 non-null   float64
 22  pr_status                       1092 non-null   object 
 23  radio_therapy                   1092 non-null   int64  
 24  3-gene_classifier_subtype       1092 non-null   object 
 25  tumor_size                      1092 non-null   float64
 26  tumor_stage                     1092 non-null   float64
 27  death_from_cancer               1092 non-null   object 
dtypes: float64(8), int64(4), object(16)
memory usage: 247.4+ KB

For my classification objective I want to know what the possible outcomes are

In [ ]:
unique_survivability = list(enumerate(bc_df.death_from_cancer.unique()))

print(unique_survivability)
[(0, 'Living'), (1, 'Died of Disease'), (2, 'Died of Other Causes')]

Died of other causes may not be of value to me because if they didn't either survive or succumb to the disease the data won't provide accurate prediction. Another way to say 'Died of Other Causes' may be that they survived the disease. But, since we don't know how they died, cancer may have ended up being the cause of death given a long enough life span.

I'll start by seeing how many values are in this category then deciding what to do with it.

In [ ]:
survivability_list = [x for x in bc_df['death_from_cancer'] if x == 'Died of Other Causes']

print(f'There are {len(survivability_list)} deaths related to other causes out of {bc_df.shape[0]} instances')
There are 238 deaths related to other causes out of 1092 instances

I'm going to drop these values even though it'll put my number of instances under 1k, smaller than desired. I'll review the class balances later and determine whether I should add some additional samples via oversampling.

In [ ]:
bool_of_unrelated_deaths = (bc_df['death_from_cancer'] == 'Died of Other Causes')
idx_matching = bc_df[bool_of_unrelated_deaths].index
bc_df = bc_df.drop(idx_matching,axis=0)
print(bc_df.shape)
(854, 28)

I'd like to understand how well balanced by two classification groups are, as that will impact how well my model may perform. So I'll check that next.

In [ ]:
import matplotlib
import matplotlib.pyplot as plt

outcomes = bc_df.groupby(['death_from_cancer'])
outcomes.count().plot(kind='pie', 
                      y='patient_id', 
                      autopct='%1.1f%%', 
                      title = "Quantity of Each Outcome")
Out[ ]:
<Axes: title={'center': 'Quantity of Each Outcome'}, ylabel='patient_id'>
No description has been provided for this image

The pie chart is a simple and visually effective way to represent the balance in my two classes. The balance is fairly close but I'd like to have them as even as possible. So I'm going to oversample by repeating several values from 1 (died_of_cancer). This is an appropriate technique because what I'm wanting my model to learn are the characteristics of the cancer that make it more lethal, thereby providing more confirmational guidance for care providers to have open and candid conversation with their patients. If I were to simply recreate instances or impute and alter the result to balance the classes, this would be an inappropriate technique to balance.

There are 854 entries total, so if I want to have each value be equal the number of instances where the individual died needs to be increased by (.567-.433)*854 = ~114 instances.

In [ ]:
# Add instances to balance the classes
bool_of_related_deaths = (bc_df['death_from_cancer'] == 'Died of Disease')
idx_matching_1 = bc_df[bool_of_related_deaths].index
bc_df_died = bc_df.loc[idx_matching_1]
bc_df_died = bc_df_died[:114] # only need a few


bc_df = pd.concat([bc_df, bc_df_died], ignore_index=True)

print(bc_df.shape)
(968, 28)
In [ ]:
# Take another look at our pie chart to verify the classes are balanced.
outcomes = bc_df.groupby(['death_from_cancer'])
outcomes.count().plot(kind='pie', 
                      y='patient_id', 
                      autopct='%1.1f%%',
                      title="Quantity of Each Outcome After Oversampling")
Out[ ]:
<Axes: title={'center': 'Quantity of Each Outcome After Oversampling'}, ylabel='patient_id'>
No description has been provided for this image

Now that my outcomes are balanced I'm going to drop patient ID before moving on. This is because it is an assigned value with no use in prediction. I can use the instance index values if needed to refer to a particular instance. Also, I'll move the feature I intend to predict to be the last one in the dataframe as more of a data visualization intent than anything else.

In [ ]:
# Drop patient id
bc_df = bc_df.drop('patient_id', axis=1) 

# Move column 'death_from_cancer' to the end
bc_df = bc_df[[col for col in bc_df.columns if col != 'death_from_cancer'] + ['death_from_cancer']]

bc_df.info(verbose=True)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 968 entries, 0 to 967
Data columns (total 27 columns):
 #   Column                          Non-Null Count  Dtype  
---  ------                          --------------  -----  
 0   age_at_diagnosis                968 non-null    float64
 1   type_of_breast_surgery          968 non-null    object 
 2   cancer_type_detailed            968 non-null    object 
 3   cellularity                     968 non-null    object 
 4   chemotherapy                    968 non-null    int64  
 5   pam50_plus_claudin-low_subtype  968 non-null    object 
 6   er_status_measured_by_ihc       968 non-null    object 
 7   er_status                       968 non-null    object 
 8   neoplasm_histologic_grade       968 non-null    float64
 9   her2_status_measured_by_snp6    968 non-null    object 
 10  her2_status                     968 non-null    object 
 11  tumor_other_histologic_subtype  968 non-null    object 
 12  hormone_therapy                 968 non-null    int64  
 13  inferred_menopausal_state       968 non-null    object 
 14  integrative_cluster             968 non-null    object 
 15  primary_tumor_laterality        968 non-null    object 
 16  lymph_nodes_examined_positive   968 non-null    float64
 17  mutation_count                  968 non-null    float64
 18  nottingham_prognostic_index     968 non-null    float64
 19  oncotree_code                   968 non-null    object 
 20  overall_survival_months         968 non-null    float64
 21  pr_status                       968 non-null    object 
 22  radio_therapy                   968 non-null    int64  
 23  3-gene_classifier_subtype       968 non-null    object 
 24  tumor_size                      968 non-null    float64
 25  tumor_stage                     968 non-null    float64
 26  death_from_cancer               968 non-null    object 
dtypes: float64(8), int64(3), object(16)
memory usage: 204.3+ KB

Final Dataset Description¶

This final dataset is comprised of 968 breast cancer patients with 27 features related to their case. At this point, the data has not been either one-hot encoded or label-encoded. I'll do that as part of my FeatureSpace setup. The target data is an object datatype containing the outcome of whether or not the patient lived, making this model a binary classifier.

The outcome (target) values datasets are balanced with half of the instances relating to patients that lived, and half the instances relating to patients that died of the disease. I've removed any instances with missing values from the dataset and I have removed gene specific z-values which comprised a great deal of the original features but were not value-added for my purposes. The data is not yet scaled.

Cross-Product Feature Identification¶

To start, I'm interested in how many unique values I have in each feature:

In [ ]:
# Source: for dictionary sorting https://stackoverflow.com/questions/64885734/how-to-sort-a-dictionary-in-descending-order-according-its-value
# unique_feature_count = [feature for feature in bc_df.columns if feature != 'death_from_cancer']
unique_feature_count = [feature for feature in bc_df.columns if bc_df[feature].dtype == object]
unique_dict = {}
for feature in unique_feature_count:
    unique_vals = len(list(enumerate(bc_df[feature].unique())))
    unique_dict[feature] = unique_vals

unique_dict_sorted = sorted(unique_dict.items(), key=lambda x:x[1], reverse=True)
for item in unique_dict_sorted:
    print(f'There are {item[1]} unique values in {item[0]}')
There are 11 unique values in integrative_cluster
There are 7 unique values in pam50_plus_claudin-low_subtype
There are 7 unique values in tumor_other_histologic_subtype
There are 5 unique values in cancer_type_detailed
There are 5 unique values in oncotree_code
There are 4 unique values in her2_status_measured_by_snp6
There are 4 unique values in 3-gene_classifier_subtype
There are 3 unique values in cellularity
There are 2 unique values in type_of_breast_surgery
There are 2 unique values in er_status_measured_by_ihc
There are 2 unique values in er_status
There are 2 unique values in her2_status
There are 2 unique values in inferred_menopausal_state
There are 2 unique values in primary_tumor_laterality
There are 2 unique values in pr_status
There are 2 unique values in death_from_cancer

To determine which features to cross I'd like to understand which ones are correlated to eachother. To do that, I need to one hot encode some of my values. Because I'm going to encode later, I'll make this as a separate dataframe.

In [ ]:
# One-hot encode other object values
# I'll write a loop for this since there are several

import copy
bc_df_encoded = bc_df.copy()
# limit to categorical features
features_to_encode = [label for label in bc_df_encoded.columns if bc_df_encoded.dtypes[label] == object]

# print(features_to_encode) # debug

for feature in features_to_encode:
    tmp_df = pd.get_dummies(bc_df[feature],prefix=feature)
    bc_df_encoded = pd.concat((bc_df_encoded,tmp_df),axis=1)
    bc_df_encoded = bc_df_encoded.drop(feature, axis=1) # drop original column
In [ ]:
# Let's pull some basic correlation data to see if that will help identify features to cross product
# Check correlation of each feature to 'death_from_cancer' first
features_to_correlate = [feature for feature in bc_df_encoded.columns 
                         if feature != 'death_from_cancer' 
                         and bc_df_encoded[feature].dtype == bool]

corr_to_outcome = [bc_df_encoded[feature].corr(bc_df_encoded['death_from_cancer_Died of Disease']) 
                   for feature in features_to_correlate] 
vars_to_use = []
for feature, value in zip(features_to_correlate, corr_to_outcome):
        if value >= 0.1: # Limit just to values 0.2
                print(f'The correlation of ', feature, ' to outcome is ', round(value,3))
                vars_to_use.append(feature)
The correlation of  type_of_breast_surgery_MASTECTOMY  to outcome is  0.201
The correlation of  pam50_plus_claudin-low_subtype_Her2  to outcome is  0.138
The correlation of  pam50_plus_claudin-low_subtype_LumB  to outcome is  0.118
The correlation of  her2_status_measured_by_snp6_GAIN  to outcome is  0.106
The correlation of  her2_status_Positive  to outcome is  0.137
The correlation of  inferred_menopausal_state_Post  to outcome is  0.1
The correlation of  integrative_cluster_5  to outcome is  0.159
The correlation of  pr_status_Negative  to outcome is  0.11
The correlation of  3-gene_classifier_subtype_ER+/HER2- High Prolif  to outcome is  0.14
The correlation of  3-gene_classifier_subtype_HER2+  to outcome is  0.12
The correlation of  death_from_cancer_Died of Disease  to outcome is  1.0

No single feature appears to be significantly correlated with my target values. However it would also be useful to understand how some of these features correlate to one-another. I'll put together a heatmap of some of these values to visualize that.

In [ ]:
# Source: modified from lab_1 in-class lectures
import seaborn as sns

# plot the correlation matrix using a subset of features

cmap = sns.set(style="darkgrid") # one of the many styles to plot using

f, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(bc_df_encoded[vars_to_use].corr(), cmap=cmap, annot=True)
Out[ ]:
<Axes: >
No description has been provided for this image

Generally, I'm going to pick things that can be thought of as rules. As such, I'm looking for things that are related or appear to move as though they are related. Therefore I'm using correlation as a guiding principle. Second, I'll look for features that seem like they would interract.

Therefore, I'm going to cross the following:

  • her2_status_measured_by_snp6 and her2_status. These two items are related in that both are used to classify the cancer subtype (Source: mayoclinic.org)
  • 3-gene_classifier_subtype and integrative_cluster. "Integrative_clustering is a breast cancer classification of 10 different subgroups with distinctive molecular profiles and clininclal outcomes" (source: https://ascopubs.org/doi/abs/10.1200/JCO.2018.36.15_suppl.579)
  • 3-gene_classifier, integrative_cluster, pam50_plus is another option to consider as all are forms of classification and should move together
  • er_status and er_status_measured_by_ihc should also move together though I didn't put them in my correlation heatmap
  • Another experiment to try would be to see if I can map chemotherapy, radio_therapy, and hormone_therapy to true/false instead of 0/1 because then I could cross the features (I tried to cross these as integer values which I quikcly learned you can't do)

Metric Selection and Reasoning¶

For this model I'll be using F-measure for assessing performance

The model will be designed as a binary classifier of the probable outcome of breast cancer based on the inputs. So the metrics of primary interest are true positives (the patient will likely die due to disesase), and false negatives (the patient died of disesase but it wasn't predicted).

I'll take a moment to discuss false positives as their impact can be viewed differently depending on perspective. A false positive for this model would be a prediction that the patient would die from disease but ends up being incorrect. This presents an obviously difficult situation for the patient due to the psychological impact of such a diagnosis. The counterbalance of this however, is that the patients outcome is much better. Despite this potential upside, the patient may make life altering decisions based on a diagnosis this model provides. Therefore I have to treat false positives with almost equal importance as true positives and false negatives.

To choose a metric that meets these requirements, I need to assess the measurements used in calculating the metrics. I would like to select a metric that emphasizes:

  • True Positives
  • False Negatives
  • False Positives

Precision Uses False Positives and False Positives. It does not weigh false negatives. Since I believe false negatives will be important to my model's performance, this is not a viable option. Recall Uses True Positives and False Negatives. Again it addresses two of my areas of interest but not all three as False Positives are missing.

F-Measure combines precision and recall with equal balances between the two. This is an ideal measure to use for my purposes as it addresses the three areas of interest with equal weightings. I'll plan on using F-measure in assessing my model's performance.

Another option to consider would be using F_beta which allows me to decrease the impact of Recall on the calculation. However, knowing that recall uses True Positives and False Negatives which are both of high importance to me, I don't want to decrement the weighting of that particular metric in this circumstance.

Methods for Dividing and Testing the Data¶

For this model I'll be using K-fold for Test/Train Split

To assess which method for my testing and training data is best I need to understand some basic information about how I've set up my data. I purposefully oversampled my data for this lab such that the results of my binary output are an even 50/50 split. Additionally, having 968 instances and 27 features, the size of the dataset is in the small to medium range.

The options I have for selecting how to split my testing and trainig data include holdout, random subsampling, and KFold or Stratified KFold. Because my dataset is of the small to medium size range, holdout and random subsampling are likely not required. Random subsampling would be more applicable in a very large datset which makes training a model on all the data an inefficient process. Holdout is a viable option, however KFold provides a more thorough understanding of how my model is performing on the data.

K-fold will help ensure I have an evenly divided test data-set, and any trends that might appear in the data are mitigated. I'll also be using shuffle to further address this. Stratified K-fold would be appropriate if my dataset was unbalanced in its results. Even without oversampling I show above, this dataset was relatively even. Therefore I should be fine using K-fold without employing the stratified technique. If I want to have an alternative option I will consider using holdout as that is the next most viable selection.

In [ ]:
# Setup my feature labels with appropriate variables needed later
import numpy as np

# create a tensorflow dataset, for ease of use later
batch_size = 44 # 44 divides evenly into my total instances, 968

# Map the classification groups to integers
# This could have been done in FeatureSpace, but doing here as a matter of preference
survivability_dict = {'Living':0, 'Died of Disease':1}
bc_df['death_from_cancer'] = bc_df['death_from_cancer'].map(survivability_dict)

categorical_headers = [label for label in bc_df.columns if bc_df.dtypes[label] == object]
int_headers = ['chemotherapy','radio_therapy','hormone_therapy']
numeric_headers = [label for label in bc_df.columns if bc_df.dtypes[label] == float] + int_headers
In [ ]:
# Perform the test/train split of the data
# Source: https://stackoverflow.com/questions/45115964/separate-pandas-dataframe-using-sklearns-kfold
# Source: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html
# Source: https://machinelearningmastery.com/k-fold-cross-validation/
from sklearn.model_selection import KFold

kf = KFold(n_splits=6, shuffle=True, random_state=1) # 6 splits resulted in ~80/20 split
result = next(kf.split(bc_df), None) # returns indices of test/train instances

bc_df_train = bc_df.iloc[result[0]]
bc_df_test =  bc_df.iloc[result[1]]

print(bc_df_train.shape)
print(bc_df_test.shape)
(806, 27)
(162, 27)

2. Modeling¶

As previously mentioned, I'm going to be setting up all of my models using FeatureSpaces. This will give me an easy way to re-use configurations as necessary. I'm going to be creating 3 models in this first section, all of which are deep and wide representations.

In [ ]:
from sklearn import metrics as mt
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)
print(keras.__version__)
2.12.0
2.12.0
In [ ]:
from tensorflow.keras.layers import Dense, Activation, Input
from tensorflow.keras.layers import Embedding, Concatenate, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.utils import FeatureSpace
In [ ]:
# Adding a function to create a tensorflow dataset from dataframe
# Source: modified from in class lecture/notebook to align with my dataset

def create_dataset_from_dataframe(df_input):

    df = df_input.copy()
    labels = df['death_from_cancer']

    df = {key: value.values[:,np.newaxis] for key, value in df_input[categorical_headers+numeric_headers].items()}
    # print(df) # debug
    # create the Dataset here
    ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))
    
    # now enable batching and prefetching
    ds = ds.batch(batch_size)
    ds = ds.prefetch(batch_size)
    
    return ds

ds_train = create_dataset_from_dataframe(bc_df_train)
ds_test = create_dataset_from_dataframe(bc_df_test)
In [ ]:
# Adding a function to create embeddings from the tensors
# Source: modified from in class lecture/notebook to align with my dataset
from tensorflow.keras.layers import Embedding, Flatten

def setup_embedding_from_categorical(feature_space, col_name):
    # what the maximum integer value for this variable?
    # which is the same as the number of categories
    N = len(feature_space.preprocessors[col_name].get_vocabulary())
    
    # get the output from the feature space, which is input to embedding
    x = feature_space.preprocessors[col_name].output
    
    # now use an embedding to deal with integers from feature space
    x = Embedding(input_dim=N, 
                  output_dim=int(np.sqrt(N)), 
                  input_length=1, name=col_name+'_embed')(x)
    
    x = Flatten()(x) # get rid of that pesky extra dimension (for time of embedding)
    
    return x # return the tensor here 
In [ ]:
#Source: Modified from in-class lecture
def setup_embedding_from_crossing(feature_space, col_name):
    # what the maximum integer value for this variable?
    
    # get the size of the feature
    N = feature_space.crossers[col_name].num_bins
    x = feature_space.crossers[col_name].output
    
    
    # now use an embedding to deal with integers as if they were one hot encoded
    x = Embedding(input_dim=N, 
                  output_dim=int(np.sqrt(N)), 
                  input_length=1, name=col_name+'_embed')(x)
    
    x = Flatten()(x) # get rid of that pesky extra dimension (for time of embedding)
    
    return x
In [ ]:
# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
# Setup functions to allow for F1 calculation
# Note I found this functionality to be depricated in my version of Keras, so it required a manual implementation
from keras import backend as K

def recall_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def f1_m(y_true, y_pred):
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))

Now I can setup my models for Keras.

Model 1 of 3¶

I'll start all of my models by either setting up or re-using a FeatureSpace.

In [ ]:
# Source: Modified from in-class lecture to match my dataset
from tensorflow.keras.utils import FeatureSpace

feature_space_1 = FeatureSpace(
    features={
        # Categorical feature encoded as string
        "type_of_breast_surgery": FeatureSpace.string_categorical(num_oov_indices=0),
        "cancer_type_detailed": FeatureSpace.string_categorical(num_oov_indices=0),
        "cellularity": FeatureSpace.string_categorical(num_oov_indices=0),
        "pam50_plus_claudin-low_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        "er_status_measured_by_ihc": FeatureSpace.string_categorical(num_oov_indices=0),
        "er_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "her2_status_measured_by_snp6": FeatureSpace.string_categorical(num_oov_indices=0),
        "her2_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        # "tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        "inferred_menopausal_state": FeatureSpace.string_categorical(num_oov_indices=0),
        "integrative_cluster": FeatureSpace.string_categorical(num_oov_indices=0),
        "primary_tumor_laterality": FeatureSpace.string_categorical(num_oov_indices=0),
        "oncotree_code": FeatureSpace.string_categorical(num_oov_indices=0),
        "pr_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "3-gene_classifier_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        # "chemotherapy": FeatureSpace.string_categorical(num_oov_indices=0),
        # "hormone_therapy": FeatureSpace.string_categorical(num_oov_indices=0),
        # "radio_therapy": FeatureSpace.string_categorical(num_oov_indices=0),
        
        # Numerical features to normalize (normalization will be learned)
        # learns the mean, variance, and if to invert
        "chemotherapy": FeatureSpace.float_normalized(),
        "hormone_therapy": FeatureSpace.float_normalized(),
        "radio_therapy": FeatureSpace.float_normalized(),
        "age_at_diagnosis": FeatureSpace.float_normalized(),
        "neoplasm_histologic_grade": FeatureSpace.float_normalized(),
        "lymph_nodes_examined_positive": FeatureSpace.float_normalized(),
        "mutation_count": FeatureSpace.float_normalized(),
        "nottingham_prognostic_index": FeatureSpace.float_normalized(),
        "overall_survival_months": FeatureSpace.float_normalized(),
        "tumor_size": FeatureSpace.float_normalized(),
        "tumor_stage": FeatureSpace.float_normalized(),
    },
    # Specify feature cross with a custom crossing dim
    crosses=[
        FeatureSpace.cross(
            feature_names=('her2_status_measured_by_snp6','her2_status'),
            crossing_dim=4*2),
        FeatureSpace.cross(
            feature_names=('3-gene_classifier_subtype', 'integrative_cluster'),
            crossing_dim=4*11),
        FeatureSpace.cross(
            feature_names=('er_status', 'er_status_measured_by_ihc'),
            crossing_dim=2*2),    
        ],
    output_mode="concat", 
)


# now that we have specified the preprocessing, let's run it on the data

# create a version of the dataset that can be iterated without labels
train_ds_with_no_labels = ds_train.map(lambda x, _: x)  
feature_space_1.adapt(train_ds_with_no_labels) # inititalize the feature map to this data
# the adapt function allows the model to learn one-hot encoding sizes

# I won't be using the pre-processed portion in my models, but I'll need it later
# now define a preprocessing operation that returns the processed features
preprocessed_ds_train = ds_train.map(lambda x, y: (feature_space_1(x), y), 
                                     num_parallel_calls=tf.data.AUTOTUNE)
# run it so that we can use the pre-processed data
preprocessed_ds_train = preprocessed_ds_train.prefetch(tf.data.AUTOTUNE)

# do the same for the test set
preprocessed_ds_test = ds_test.map(lambda x, y: (feature_space_1(x), y), num_parallel_calls=tf.data.AUTOTUNE)
preprocessed_ds_test = preprocessed_ds_test.prefetch(tf.data.AUTOTUNE)
In [ ]:
# Source: Modified from in-class lecture to match my dataset
dict_inputs = feature_space_1.get_inputs() # need to use unprocessed features here, to gain access to each output

# we need to create separate lists for each branch
crossed_outputs = []

# for each crossed variable, make an embedding
for col in feature_space_1.crossers.keys():
    
    x = setup_embedding_from_crossing(feature_space_1, col)
    
    # save these outputs in list to concatenate later
    crossed_outputs.append(x)
    

# now concatenate the outputs and add a fully connected layer
wide_branch = Concatenate(name='wide_concat')(crossed_outputs)

# reset this input branch
all_deep_branch_outputs = []

# for each numeric variable, just add it in after embedding
for idx,col in enumerate(numeric_headers):
    x = feature_space_1.preprocessors[col].output
    x = tf.cast(x,float) # cast an integer as a float here
    all_deep_branch_outputs.append(x)
    
# for each categorical variable
for col in categorical_headers:
    
    # get the output tensor from ebedding layer
    x = setup_embedding_from_categorical(feature_space_1, col)
    
    # save these outputs in list to concatenate later
    all_deep_branch_outputs.append(x)


# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34,activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17,activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=10,activation='relu', name='deep3')(deep_branch)
    
# merge the deep and wide branch
final_branch = Concatenate(name='concat_deep_wide')([deep_branch, wide_branch])
final_branch = Dense(units=1,activation='sigmoid',
                     name='combined')(final_branch)

training_model_1 = keras.Model(inputs=dict_inputs, outputs=final_branch)

# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_1.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=['acc',f1_m,precision_m, recall_m]
)

training_model_1.summary()

plot_model(
    training_model_1, to_file='model.png', show_shapes=True, show_layer_names=True,
    rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_35"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 type_of_breast_surgery (InputL  [(None, 1)]         0           []                               
 ayer)                                                                                            
                                                                                                  
 cancer_type_detailed (InputLay  [(None, 1)]         0           []                               
 er)                                                                                              
                                                                                                  
 cellularity (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 pam50_plus_claudin-low_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 er_status_measured_by_ihc (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 er_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 her2_status_measured_by_snp6 (  [(None, 1)]         0           []                               
 InputLayer)                                                                                      
                                                                                                  
 her2_status (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 tumor_other_histologic_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 inferred_menopausal_state (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 integrative_cluster (InputLaye  [(None, 1)]         0           []                               
 r)                                                                                               
                                                                                                  
 primary_tumor_laterality (Inpu  [(None, 1)]         0           []                               
 tLayer)                                                                                          
                                                                                                  
 oncotree_code (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 pr_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 3-gene_classifier_subtype (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 age_at_diagnosis (InputLayer)  [(None, 1)]          0           []                               
                                                                                                  
 neoplasm_histologic_grade (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 lymph_nodes_examined_positive   [(None, 1)]         0           []                               
 (InputLayer)                                                                                     
                                                                                                  
 mutation_count (InputLayer)    [(None, 1)]          0           []                               
                                                                                                  
 nottingham_prognostic_index (I  [(None, 1)]         0           []                               
 nputLayer)                                                                                       
                                                                                                  
 overall_survival_months (Input  [(None, 1)]         0           []                               
 Layer)                                                                                           
                                                                                                  
 tumor_size (InputLayer)        [(None, 1)]          0           []                               
                                                                                                  
 tumor_stage (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 chemotherapy (InputLayer)      [(None, 1)]          0           []                               
                                                                                                  
 radio_therapy (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 hormone_therapy (InputLayer)   [(None, 1)]          0           []                               
                                                                                                  
 string_categorical_424_preproc  (None, 1)           0           ['type_of_breast_surgery[0][0]'] 
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_425_preproc  (None, 1)           0           ['cancer_type_detailed[0][0]']   
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_426_preproc  (None, 1)           0           ['cellularity[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_427_preproc  (None, 1)           0           ['pam50_plus_claudin-low_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_428_preproc  (None, 1)           0           ['er_status_measured_by_ihc[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_429_preproc  (None, 1)           0           ['er_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_430_preproc  (None, 1)           0           ['her2_status_measured_by_snp6[0]
 essor (StringLookup)                                            [0]']                            
                                                                                                  
 string_categorical_431_preproc  (None, 1)           0           ['her2_status[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_432_preproc  (None, 1)           0           ['tumor_other_histologic_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_433_preproc  (None, 1)           0           ['inferred_menopausal_state[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_434_preproc  (None, 1)           0           ['integrative_cluster[0][0]']    
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_435_preproc  (None, 1)           0           ['primary_tumor_laterality[0][0]'
 essor (StringLookup)                                            ]                                
                                                                                                  
 string_categorical_436_preproc  (None, 1)           0           ['oncotree_code[0][0]']          
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_437_preproc  (None, 1)           0           ['pr_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_438_preproc  (None, 1)           0           ['3-gene_classifier_subtype[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 float_normalized_306_preproces  (None, 1)           3           ['age_at_diagnosis[0][0]']       
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_307_preproces  (None, 1)           3           ['neoplasm_histologic_grade[0][0]
 sor (Normalization)                                             ']                               
                                                                                                  
 float_normalized_308_preproces  (None, 1)           3           ['lymph_nodes_examined_positive[0
 sor (Normalization)                                             ][0]']                           
                                                                                                  
 float_normalized_309_preproces  (None, 1)           3           ['mutation_count[0][0]']         
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_310_preproces  (None, 1)           3           ['nottingham_prognostic_index[0][
 sor (Normalization)                                             0]']                             
                                                                                                  
 float_normalized_311_preproces  (None, 1)           3           ['overall_survival_months[0][0]']
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_312_preproces  (None, 1)           3           ['tumor_size[0][0]']             
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_313_preproces  (None, 1)           3           ['tumor_stage[0][0]']            
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_303_preproces  (None, 1)           3           ['chemotherapy[0][0]']           
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_305_preproces  (None, 1)           3           ['radio_therapy[0][0]']          
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_304_preproces  (None, 1)           3           ['hormone_therapy[0][0]']        
 sor (Normalization)                                                                              
                                                                                                  
 type_of_breast_surgery_embed (  (None, 1, 1)        2           ['string_categorical_424_preproce
 Embedding)                                                      ssor[0][0]']                     
                                                                                                  
 cancer_type_detailed_embed (Em  (None, 1, 2)        10          ['string_categorical_425_preproce
 bedding)                                                        ssor[0][0]']                     
                                                                                                  
 cellularity_embed (Embedding)  (None, 1, 1)         3           ['string_categorical_426_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 pam50_plus_claudin-low_subtype  (None, 1, 2)        14          ['string_categorical_427_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 er_status_measured_by_ihc_embe  (None, 1, 1)        2           ['string_categorical_428_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 er_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_429_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 her2_status_measured_by_snp6_e  (None, 1, 2)        8           ['string_categorical_430_preproce
 mbed (Embedding)                                                ssor[0][0]']                     
                                                                                                  
 her2_status_embed (Embedding)  (None, 1, 1)         2           ['string_categorical_431_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 tumor_other_histologic_subtype  (None, 1, 2)        14          ['string_categorical_432_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 inferred_menopausal_state_embe  (None, 1, 1)        2           ['string_categorical_433_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 integrative_cluster_embed (Emb  (None, 1, 3)        33          ['string_categorical_434_preproce
 edding)                                                         ssor[0][0]']                     
                                                                                                  
 primary_tumor_laterality_embed  (None, 1, 1)        2           ['string_categorical_435_preproce
  (Embedding)                                                    ssor[0][0]']                     
                                                                                                  
 oncotree_code_embed (Embedding  (None, 1, 2)        10          ['string_categorical_436_preproce
 )                                                               ssor[0][0]']                     
                                                                                                  
 pr_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_437_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 3-gene_classifier_subtype_embe  (None, 1, 2)        8           ['string_categorical_438_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 tf.cast_385 (TFOpLambda)       (None, 1)            0           ['float_normalized_306_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_386 (TFOpLambda)       (None, 1)            0           ['float_normalized_307_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_387 (TFOpLambda)       (None, 1)            0           ['float_normalized_308_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_388 (TFOpLambda)       (None, 1)            0           ['float_normalized_309_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_389 (TFOpLambda)       (None, 1)            0           ['float_normalized_310_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_390 (TFOpLambda)       (None, 1)            0           ['float_normalized_311_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_391 (TFOpLambda)       (None, 1)            0           ['float_normalized_312_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_392 (TFOpLambda)       (None, 1)            0           ['float_normalized_313_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_393 (TFOpLambda)       (None, 1)            0           ['float_normalized_303_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_394 (TFOpLambda)       (None, 1)            0           ['float_normalized_305_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_395 (TFOpLambda)       (None, 1)            0           ['float_normalized_304_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 flatten_620 (Flatten)          (None, 1)            0           ['type_of_breast_surgery_embed[0]
                                                                 [0]']                            
                                                                                                  
 flatten_621 (Flatten)          (None, 2)            0           ['cancer_type_detailed_embed[0][0
                                                                 ]']                              
                                                                                                  
 flatten_622 (Flatten)          (None, 1)            0           ['cellularity_embed[0][0]']      
                                                                                                  
 flatten_623 (Flatten)          (None, 2)            0           ['pam50_plus_claudin-low_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_624 (Flatten)          (None, 1)            0           ['er_status_measured_by_ihc_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_625 (Flatten)          (None, 1)            0           ['er_status_embed[0][0]']        
                                                                                                  
 flatten_626 (Flatten)          (None, 2)            0           ['her2_status_measured_by_snp6_em
                                                                 bed[0][0]']                      
                                                                                                  
 flatten_627 (Flatten)          (None, 1)            0           ['her2_status_embed[0][0]']      
                                                                                                  
 flatten_628 (Flatten)          (None, 2)            0           ['tumor_other_histologic_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_629 (Flatten)          (None, 1)            0           ['inferred_menopausal_state_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_630 (Flatten)          (None, 3)            0           ['integrative_cluster_embed[0][0]
                                                                 ']                               
                                                                                                  
 flatten_631 (Flatten)          (None, 1)            0           ['primary_tumor_laterality_embed[
                                                                 0][0]']                          
                                                                                                  
 flatten_632 (Flatten)          (None, 2)            0           ['oncotree_code_embed[0][0]']    
                                                                                                  
 flatten_633 (Flatten)          (None, 1)            0           ['pr_status_embed[0][0]']        
                                                                                                  
 flatten_634 (Flatten)          (None, 2)            0           ['3-gene_classifier_subtype_embed
                                                                 [0][0]']                         
                                                                                                  
 embed_concat (Concatenate)     (None, 34)           0           ['tf.cast_385[0][0]',            
                                                                  'tf.cast_386[0][0]',            
                                                                  'tf.cast_387[0][0]',            
                                                                  'tf.cast_388[0][0]',            
                                                                  'tf.cast_389[0][0]',            
                                                                  'tf.cast_390[0][0]',            
                                                                  'tf.cast_391[0][0]',            
                                                                  'tf.cast_392[0][0]',            
                                                                  'tf.cast_393[0][0]',            
                                                                  'tf.cast_394[0][0]',            
                                                                  'tf.cast_395[0][0]',            
                                                                  'flatten_620[0][0]',            
                                                                  'flatten_621[0][0]',            
                                                                  'flatten_622[0][0]',            
                                                                  'flatten_623[0][0]',            
                                                                  'flatten_624[0][0]',            
                                                                  'flatten_625[0][0]',            
                                                                  'flatten_626[0][0]',            
                                                                  'flatten_627[0][0]',            
                                                                  'flatten_628[0][0]',            
                                                                  'flatten_629[0][0]',            
                                                                  'flatten_630[0][0]',            
                                                                  'flatten_631[0][0]',            
                                                                  'flatten_632[0][0]',            
                                                                  'flatten_633[0][0]',            
                                                                  'flatten_634[0][0]']            
                                                                                                  
 her2_status_measured_by_snp6_X  (None, 1)           0           ['string_categorical_430_preproce
 _her2_status (HashedCrossing)                                   ssor[0][0]',                     
                                                                  'string_categorical_431_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 3-gene_classifier_subtype_X_in  (None, 1)           0           ['string_categorical_438_preproce
 tegrative_cluster (HashedCross                                  ssor[0][0]',                     
 ing)                                                             'string_categorical_434_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 er_status_X_er_status_measured  (None, 1)           0           ['string_categorical_429_preproce
 _by_ihc (HashedCrossing)                                        ssor[0][0]',                     
                                                                  'string_categorical_428_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 deep1 (Dense)                  (None, 34)           1190        ['embed_concat[0][0]']           
                                                                                                  
 her2_status_measured_by_snp6_X  (None, 1, 2)        16          ['her2_status_measured_by_snp6_X_
 _her2_status_embed (Embedding)                                  her2_status[0][0]']              
                                                                                                  
 3-gene_classifier_subtype_X_in  (None, 1, 6)        264         ['3-gene_classifier_subtype_X_int
 tegrative_cluster_embed (Embed                                  egrative_cluster[0][0]']         
 ding)                                                                                            
                                                                                                  
 er_status_X_er_status_measured  (None, 1, 2)        8           ['er_status_X_er_status_measured_
 _by_ihc_embed (Embedding)                                       by_ihc[0][0]']                   
                                                                                                  
 deep2 (Dense)                  (None, 17)           595         ['deep1[0][0]']                  
                                                                                                  
 flatten_617 (Flatten)          (None, 2)            0           ['her2_status_measured_by_snp6_X_
                                                                 her2_status_embed[0][0]']        
                                                                                                  
 flatten_618 (Flatten)          (None, 6)            0           ['3-gene_classifier_subtype_X_int
                                                                 egrative_cluster_embed[0][0]']   
                                                                                                  
 flatten_619 (Flatten)          (None, 2)            0           ['er_status_X_er_status_measured_
                                                                 by_ihc_embed[0][0]']             
                                                                                                  
 deep3 (Dense)                  (None, 10)           180         ['deep2[0][0]']                  
                                                                                                  
 wide_concat (Concatenate)      (None, 10)           0           ['flatten_617[0][0]',            
                                                                  'flatten_618[0][0]',            
                                                                  'flatten_619[0][0]']            
                                                                                                  
 concat_deep_wide (Concatenate)  (None, 20)          0           ['deep3[0][0]',                  
                                                                  'wide_concat[0][0]']            
                                                                                                  
 combined (Dense)               (None, 1)            21          ['concat_deep_wide[0][0]']       
                                                                                                  
==================================================================================================
Total params: 2,421
Trainable params: 2,388
Non-trainable params: 33
__________________________________________________________________________________________________
Out[ ]:
No description has been provided for this image
In [ ]:
# Train the model
history_1 = training_model_1.fit(
    ds_train, epochs=25, validation_data=ds_test, verbose=2
)
Epoch 1/25
19/19 - 4s - loss: 0.6937 - acc: 0.5211 - f1_m: 0.1378 - precision_m: 0.5921 - recall_m: 0.0799 - val_loss: 0.6872 - val_acc: 0.5247 - val_f1_m: 0.2440 - val_precision_m: 0.6726 - val_recall_m: 0.1592 - 4s/epoch - 236ms/step
Epoch 2/25
19/19 - 0s - loss: 0.6610 - acc: 0.6563 - f1_m: 0.4916 - precision_m: 0.8188 - recall_m: 0.3629 - val_loss: 0.6650 - val_acc: 0.6481 - val_f1_m: 0.5436 - val_precision_m: 0.7738 - val_recall_m: 0.4575 - 73ms/epoch - 4ms/step
Epoch 3/25
19/19 - 0s - loss: 0.6370 - acc: 0.7072 - f1_m: 0.6399 - precision_m: 0.7814 - recall_m: 0.5686 - val_loss: 0.6426 - val_acc: 0.6914 - val_f1_m: 0.6419 - val_precision_m: 0.7453 - val_recall_m: 0.5908 - 71ms/epoch - 4ms/step
Epoch 4/25
19/19 - 0s - loss: 0.6084 - acc: 0.7233 - f1_m: 0.6792 - precision_m: 0.7612 - recall_m: 0.6501 - val_loss: 0.6136 - val_acc: 0.7099 - val_f1_m: 0.6848 - val_precision_m: 0.7398 - val_recall_m: 0.6750 - 74ms/epoch - 4ms/step
Epoch 5/25
19/19 - 0s - loss: 0.5751 - acc: 0.7531 - f1_m: 0.7241 - precision_m: 0.7796 - recall_m: 0.7207 - val_loss: 0.5801 - val_acc: 0.7160 - val_f1_m: 0.6959 - val_precision_m: 0.7464 - val_recall_m: 0.6958 - 72ms/epoch - 4ms/step
Epoch 6/25
19/19 - 0s - loss: 0.5426 - acc: 0.7568 - f1_m: 0.7268 - precision_m: 0.7738 - recall_m: 0.7275 - val_loss: 0.5459 - val_acc: 0.7469 - val_f1_m: 0.7314 - val_precision_m: 0.7767 - val_recall_m: 0.7258 - 76ms/epoch - 4ms/step
Epoch 7/25
19/19 - 0s - loss: 0.5143 - acc: 0.7655 - f1_m: 0.7351 - precision_m: 0.7764 - recall_m: 0.7379 - val_loss: 0.5161 - val_acc: 0.7531 - val_f1_m: 0.7438 - val_precision_m: 0.7765 - val_recall_m: 0.7567 - 109ms/epoch - 6ms/step
Epoch 8/25
19/19 - 0s - loss: 0.4910 - acc: 0.7742 - f1_m: 0.7469 - precision_m: 0.7830 - recall_m: 0.7535 - val_loss: 0.4912 - val_acc: 0.7716 - val_f1_m: 0.7688 - val_precision_m: 0.7714 - val_recall_m: 0.7992 - 72ms/epoch - 4ms/step
Epoch 9/25
19/19 - 0s - loss: 0.4716 - acc: 0.7816 - f1_m: 0.7618 - precision_m: 0.7834 - recall_m: 0.7772 - val_loss: 0.4714 - val_acc: 0.7963 - val_f1_m: 0.7937 - val_precision_m: 0.8052 - val_recall_m: 0.8200 - 90ms/epoch - 5ms/step
Epoch 10/25
19/19 - 0s - loss: 0.4562 - acc: 0.7965 - f1_m: 0.7793 - precision_m: 0.7951 - recall_m: 0.7942 - val_loss: 0.4565 - val_acc: 0.8148 - val_f1_m: 0.8104 - val_precision_m: 0.8170 - val_recall_m: 0.8400 - 77ms/epoch - 4ms/step
Epoch 11/25
19/19 - 0s - loss: 0.4438 - acc: 0.8040 - f1_m: 0.7893 - precision_m: 0.7950 - recall_m: 0.8138 - val_loss: 0.4461 - val_acc: 0.8148 - val_f1_m: 0.8104 - val_precision_m: 0.8170 - val_recall_m: 0.8400 - 69ms/epoch - 4ms/step
Epoch 12/25
19/19 - 0s - loss: 0.4342 - acc: 0.8040 - f1_m: 0.7893 - precision_m: 0.7951 - recall_m: 0.8140 - val_loss: 0.4388 - val_acc: 0.8148 - val_f1_m: 0.8104 - val_precision_m: 0.8222 - val_recall_m: 0.8400 - 71ms/epoch - 4ms/step
Epoch 13/25
19/19 - 0s - loss: 0.4268 - acc: 0.8065 - f1_m: 0.7906 - precision_m: 0.7976 - recall_m: 0.8120 - val_loss: 0.4338 - val_acc: 0.8148 - val_f1_m: 0.8104 - val_precision_m: 0.8222 - val_recall_m: 0.8400 - 74ms/epoch - 4ms/step
Epoch 14/25
19/19 - 0s - loss: 0.4203 - acc: 0.8077 - f1_m: 0.7915 - precision_m: 0.7980 - recall_m: 0.8120 - val_loss: 0.4302 - val_acc: 0.8210 - val_f1_m: 0.8157 - val_precision_m: 0.8287 - val_recall_m: 0.8400 - 77ms/epoch - 4ms/step
Epoch 15/25
19/19 - 0s - loss: 0.4142 - acc: 0.8089 - f1_m: 0.7926 - precision_m: 0.7997 - recall_m: 0.8115 - val_loss: 0.4283 - val_acc: 0.8272 - val_f1_m: 0.8214 - val_precision_m: 0.8359 - val_recall_m: 0.8400 - 71ms/epoch - 4ms/step
Epoch 16/25
19/19 - 0s - loss: 0.4089 - acc: 0.8139 - f1_m: 0.7970 - precision_m: 0.8023 - recall_m: 0.8157 - val_loss: 0.4272 - val_acc: 0.8272 - val_f1_m: 0.8214 - val_precision_m: 0.8359 - val_recall_m: 0.8400 - 73ms/epoch - 4ms/step
Epoch 17/25
19/19 - 0s - loss: 0.4045 - acc: 0.8164 - f1_m: 0.8005 - precision_m: 0.8061 - recall_m: 0.8193 - val_loss: 0.4266 - val_acc: 0.8272 - val_f1_m: 0.8214 - val_precision_m: 0.8359 - val_recall_m: 0.8400 - 72ms/epoch - 4ms/step
Epoch 18/25
19/19 - 0s - loss: 0.4001 - acc: 0.8176 - f1_m: 0.8030 - precision_m: 0.8072 - recall_m: 0.8235 - val_loss: 0.4264 - val_acc: 0.8210 - val_f1_m: 0.8146 - val_precision_m: 0.8345 - val_recall_m: 0.8300 - 75ms/epoch - 4ms/step
Epoch 19/25
19/19 - 0s - loss: 0.3957 - acc: 0.8213 - f1_m: 0.8063 - precision_m: 0.8115 - recall_m: 0.8235 - val_loss: 0.4265 - val_acc: 0.8210 - val_f1_m: 0.8146 - val_precision_m: 0.8345 - val_recall_m: 0.8300 - 72ms/epoch - 4ms/step
Epoch 20/25
19/19 - 0s - loss: 0.3919 - acc: 0.8213 - f1_m: 0.8063 - precision_m: 0.8115 - recall_m: 0.8235 - val_loss: 0.4266 - val_acc: 0.8210 - val_f1_m: 0.8146 - val_precision_m: 0.8345 - val_recall_m: 0.8300 - 73ms/epoch - 4ms/step
Epoch 21/25
19/19 - 0s - loss: 0.3878 - acc: 0.8251 - f1_m: 0.8097 - precision_m: 0.8150 - recall_m: 0.8261 - val_loss: 0.4267 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 71ms/epoch - 4ms/step
Epoch 22/25
19/19 - 0s - loss: 0.3840 - acc: 0.8288 - f1_m: 0.8131 - precision_m: 0.8187 - recall_m: 0.8289 - val_loss: 0.4269 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 78ms/epoch - 4ms/step
Epoch 23/25
19/19 - 0s - loss: 0.3805 - acc: 0.8300 - f1_m: 0.8161 - precision_m: 0.8186 - recall_m: 0.8365 - val_loss: 0.4273 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 71ms/epoch - 4ms/step
Epoch 24/25
19/19 - 0s - loss: 0.3766 - acc: 0.8325 - f1_m: 0.8178 - precision_m: 0.8224 - recall_m: 0.8340 - val_loss: 0.4277 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 87ms/epoch - 5ms/step
Epoch 25/25
19/19 - 0s - loss: 0.3734 - acc: 0.8325 - f1_m: 0.8180 - precision_m: 0.8233 - recall_m: 0.8353 - val_loss: 0.4279 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 78ms/epoch - 4ms/step
In [ ]:
# Print plots of metrics
from matplotlib import pyplot as plt

%matplotlib inline

plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_1.history['f1_m'])

plt.ylabel('F1 Score %')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_1.history['val_f1_m'])
plt.title('Validation')

plt.subplot(2,2,3)
plt.plot(history_1.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')

plt.subplot(2,2,4)
plt.plot(history_1.history['val_loss'])
plt.xlabel('epochs')
Out[ ]:
Text(0.5, 0, 'epochs')
No description has been provided for this image

I see convergence in this model at around 20-25 epochs. F1 score is good on validation and training data as well. Ideally I'd like to see my training loss get lower than this as I want it to be as low as possible. Next I'll check the confusion matrix to see my ratio of True Positives, True Negatives, False Positives, and False Negatives.

In [ ]:
#  Vizualize some metrics associated with this model

# Source: Modified from in-class lecture
# Use the sklearn metrics here, if you want to
from sklearn import metrics as mt

y_test = tf.concat([y for x, y in ds_test], axis=0)
y_test = y_test.numpy()

# now lets see how well the model performed
yhat_proba_1 = training_model_1.predict(ds_test) # sigmoid output probabilities
# use squeeze to get rid of any extra dimensions 
yhat_1 = np.round(yhat_proba_1.squeeze()) # round to get binary class

conf_mat_1 = mt.confusion_matrix(y_test, yhat_1)

print(conf_mat_1)
print(mt.classification_report(y_test,yhat_1))

# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 20180309.  VitalBook file.
# Create pandas dataframe
conf_df_1 = pd.DataFrame(conf_mat_1, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())

# Create heatmap
sns.heatmap(conf_df_1, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 1s 3ms/step
[[66 14]
 [16 66]]
              precision    recall  f1-score   support

           0       0.80      0.82      0.81        80
           1       0.82      0.80      0.81        82

    accuracy                           0.81       162
   macro avg       0.81      0.81      0.81       162
weighted avg       0.82      0.81      0.81       162

No description has been provided for this image

Not a bad result. My True Positives and True Negatives are relatively high compared with False Positives and False Negatives. However I'd really like to see fewer false positives. As I stated above they can be bad for patient well-being due to the decisions they may drive. We'll see if we can improve on this in subsequent models.

Model 2 of 3¶

This model changes some of the cross features to see if it impacts results. In model 2, I'm going to see if I can try some of the other feature space arrangements discussed above. So I'll use the following:

  • her2_status_measured_by_snp6 and her2_status (old)
  • 3-gene_classifier, integrative_cluster, pam50_plus (new)
  • er_status and er_status_measured_by_ihc (old)
In [ ]:
# Source: Modified from in-class lecture to match my dataset
from tensorflow.keras.utils import FeatureSpace

feature_space_2 = FeatureSpace(
    features={
        # Categorical feature encoded as string
        "type_of_breast_surgery": FeatureSpace.string_categorical(num_oov_indices=0),
        "cancer_type_detailed": FeatureSpace.string_categorical(num_oov_indices=0),
        "cellularity": FeatureSpace.string_categorical(num_oov_indices=0),
        "pam50_plus_claudin-low_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        "er_status_measured_by_ihc": FeatureSpace.string_categorical(num_oov_indices=0),
        "er_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "her2_status_measured_by_snp6": FeatureSpace.string_categorical(num_oov_indices=0),
        "her2_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        # "tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        "inferred_menopausal_state": FeatureSpace.string_categorical(num_oov_indices=0),
        "integrative_cluster": FeatureSpace.string_categorical(num_oov_indices=0),
        "primary_tumor_laterality": FeatureSpace.string_categorical(num_oov_indices=0),
        "oncotree_code": FeatureSpace.string_categorical(num_oov_indices=0),
        "pr_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "3-gene_classifier_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        
        # Numerical features to normalize (normalization will be learned)
        # learns the mean, variance, and if to invert
        "chemotherapy": FeatureSpace.float_normalized(),
        "hormone_therapy": FeatureSpace.float_normalized(),
        "radio_therapy": FeatureSpace.float_normalized(),
        "age_at_diagnosis": FeatureSpace.float_normalized(),
        "neoplasm_histologic_grade": FeatureSpace.float_normalized(),
        "lymph_nodes_examined_positive": FeatureSpace.float_normalized(),
        "mutation_count": FeatureSpace.float_normalized(),
        "nottingham_prognostic_index": FeatureSpace.float_normalized(),
        "overall_survival_months": FeatureSpace.float_normalized(),
        "tumor_size": FeatureSpace.float_normalized(),
        "tumor_stage": FeatureSpace.float_normalized(),
    },
    # Specify feature cross with a custom crossing dim
    crosses=[
        FeatureSpace.cross(
            feature_names=('her2_status_measured_by_snp6','her2_status'),
            crossing_dim=4*2),
        FeatureSpace.cross(
            feature_names=('3-gene_classifier_subtype', 'integrative_cluster', 'pam50_plus_claudin-low_subtype'),
            crossing_dim=4*11*7),
        FeatureSpace.cross(
            feature_names=('er_status', 'er_status_measured_by_ihc'),
            crossing_dim=2*2),    
        ],
    output_mode="concat", 
)


# now that we have specified the preprocessing, let's run it on the data

# create a version of the dataset that can be iterated without labels
train_ds_with_no_labels = ds_train.map(lambda x, _: x)  
feature_space_2.adapt(train_ds_with_no_labels) # inititalize the feature map to this data
# the adapt function allows the model to learn one-hot encoding sizes

# now define a preprocessing operation that returns the processed features
# preprocessed_ds_train = ds_train.map(lambda x, y: (feature_space_2(x), y), 
#                                      num_parallel_calls=tf.data.AUTOTUNE)
# # run it so that we can use the pre-processed data
# preprocessed_ds_train = preprocessed_ds_train.prefetch(tf.data.AUTOTUNE)

# # do the same for the test set
# preprocessed_ds_test = ds_test.map(lambda x, y: (feature_space_2(x), y), num_parallel_calls=tf.data.AUTOTUNE)
# preprocessed_ds_test = preprocessed_ds_test.prefetch(tf.data.AUTOTUNE)
In [ ]:
# from keras.metrics import Precision, Recall
dict_inputs = feature_space_2.get_inputs() # need to use unprocessed features here, to gain access to each output

# we need to create separate lists for each branch
crossed_outputs = []

# for each crossed variable, make an embedding
for col in feature_space_2.crossers.keys():
    
    x = setup_embedding_from_crossing(feature_space_2, col)
    
    # save these outputs in list to concatenate later
    crossed_outputs.append(x)
    

# now concatenate the outputs and add a fully connected layer
wide_branch = Concatenate(name='wide_concat')(crossed_outputs)

# reset this input branch
all_deep_branch_outputs = []

# for each numeric variable, just add it in after embedding
for idx,col in enumerate(numeric_headers):
    x = feature_space_2.preprocessors[col].output
    x = tf.cast(x,float) # cast an integer as a float here
    all_deep_branch_outputs.append(x)
    
# for each categorical variable
for col in categorical_headers:
    
    # get the output tensor from ebedding layer
    x = setup_embedding_from_categorical(feature_space_2, col)
    
    # save these outputs in list to concatenate later
    all_deep_branch_outputs.append(x)


# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34,activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17,activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=10,activation='relu', name='deep3')(deep_branch)
    
# merge the deep and wide branch
final_branch = Concatenate(name='concat_deep_wide')([deep_branch, wide_branch])
final_branch = Dense(units=1,activation='sigmoid',
                     name='combined')(final_branch)

training_model_2 = keras.Model(inputs=dict_inputs, outputs=final_branch)

# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_2.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=['acc',f1_m,precision_m, recall_m]
)

training_model_2.summary()

plot_model(
    training_model_2, to_file='model.png', show_shapes=True, show_layer_names=True,
    rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_36"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 type_of_breast_surgery (InputL  [(None, 1)]         0           []                               
 ayer)                                                                                            
                                                                                                  
 cancer_type_detailed (InputLay  [(None, 1)]         0           []                               
 er)                                                                                              
                                                                                                  
 cellularity (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 pam50_plus_claudin-low_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 er_status_measured_by_ihc (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 er_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 her2_status_measured_by_snp6 (  [(None, 1)]         0           []                               
 InputLayer)                                                                                      
                                                                                                  
 her2_status (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 tumor_other_histologic_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 inferred_menopausal_state (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 integrative_cluster (InputLaye  [(None, 1)]         0           []                               
 r)                                                                                               
                                                                                                  
 primary_tumor_laterality (Inpu  [(None, 1)]         0           []                               
 tLayer)                                                                                          
                                                                                                  
 oncotree_code (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 pr_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 3-gene_classifier_subtype (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 age_at_diagnosis (InputLayer)  [(None, 1)]          0           []                               
                                                                                                  
 neoplasm_histologic_grade (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 lymph_nodes_examined_positive   [(None, 1)]         0           []                               
 (InputLayer)                                                                                     
                                                                                                  
 mutation_count (InputLayer)    [(None, 1)]          0           []                               
                                                                                                  
 nottingham_prognostic_index (I  [(None, 1)]         0           []                               
 nputLayer)                                                                                       
                                                                                                  
 overall_survival_months (Input  [(None, 1)]         0           []                               
 Layer)                                                                                           
                                                                                                  
 tumor_size (InputLayer)        [(None, 1)]          0           []                               
                                                                                                  
 tumor_stage (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 chemotherapy (InputLayer)      [(None, 1)]          0           []                               
                                                                                                  
 radio_therapy (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 hormone_therapy (InputLayer)   [(None, 1)]          0           []                               
                                                                                                  
 string_categorical_439_preproc  (None, 1)           0           ['type_of_breast_surgery[0][0]'] 
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_440_preproc  (None, 1)           0           ['cancer_type_detailed[0][0]']   
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_441_preproc  (None, 1)           0           ['cellularity[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_442_preproc  (None, 1)           0           ['pam50_plus_claudin-low_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_443_preproc  (None, 1)           0           ['er_status_measured_by_ihc[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_444_preproc  (None, 1)           0           ['er_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_445_preproc  (None, 1)           0           ['her2_status_measured_by_snp6[0]
 essor (StringLookup)                                            [0]']                            
                                                                                                  
 string_categorical_446_preproc  (None, 1)           0           ['her2_status[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_447_preproc  (None, 1)           0           ['tumor_other_histologic_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_448_preproc  (None, 1)           0           ['inferred_menopausal_state[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_449_preproc  (None, 1)           0           ['integrative_cluster[0][0]']    
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_450_preproc  (None, 1)           0           ['primary_tumor_laterality[0][0]'
 essor (StringLookup)                                            ]                                
                                                                                                  
 string_categorical_451_preproc  (None, 1)           0           ['oncotree_code[0][0]']          
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_452_preproc  (None, 1)           0           ['pr_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_453_preproc  (None, 1)           0           ['3-gene_classifier_subtype[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 float_normalized_317_preproces  (None, 1)           3           ['age_at_diagnosis[0][0]']       
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_318_preproces  (None, 1)           3           ['neoplasm_histologic_grade[0][0]
 sor (Normalization)                                             ']                               
                                                                                                  
 float_normalized_319_preproces  (None, 1)           3           ['lymph_nodes_examined_positive[0
 sor (Normalization)                                             ][0]']                           
                                                                                                  
 float_normalized_320_preproces  (None, 1)           3           ['mutation_count[0][0]']         
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_321_preproces  (None, 1)           3           ['nottingham_prognostic_index[0][
 sor (Normalization)                                             0]']                             
                                                                                                  
 float_normalized_322_preproces  (None, 1)           3           ['overall_survival_months[0][0]']
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_323_preproces  (None, 1)           3           ['tumor_size[0][0]']             
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_324_preproces  (None, 1)           3           ['tumor_stage[0][0]']            
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_314_preproces  (None, 1)           3           ['chemotherapy[0][0]']           
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_316_preproces  (None, 1)           3           ['radio_therapy[0][0]']          
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_315_preproces  (None, 1)           3           ['hormone_therapy[0][0]']        
 sor (Normalization)                                                                              
                                                                                                  
 type_of_breast_surgery_embed (  (None, 1, 1)        2           ['string_categorical_439_preproce
 Embedding)                                                      ssor[0][0]']                     
                                                                                                  
 cancer_type_detailed_embed (Em  (None, 1, 2)        10          ['string_categorical_440_preproce
 bedding)                                                        ssor[0][0]']                     
                                                                                                  
 cellularity_embed (Embedding)  (None, 1, 1)         3           ['string_categorical_441_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 pam50_plus_claudin-low_subtype  (None, 1, 2)        14          ['string_categorical_442_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 er_status_measured_by_ihc_embe  (None, 1, 1)        2           ['string_categorical_443_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 er_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_444_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 her2_status_measured_by_snp6_e  (None, 1, 2)        8           ['string_categorical_445_preproce
 mbed (Embedding)                                                ssor[0][0]']                     
                                                                                                  
 her2_status_embed (Embedding)  (None, 1, 1)         2           ['string_categorical_446_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 tumor_other_histologic_subtype  (None, 1, 2)        14          ['string_categorical_447_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 inferred_menopausal_state_embe  (None, 1, 1)        2           ['string_categorical_448_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 integrative_cluster_embed (Emb  (None, 1, 3)        33          ['string_categorical_449_preproce
 edding)                                                         ssor[0][0]']                     
                                                                                                  
 primary_tumor_laterality_embed  (None, 1, 1)        2           ['string_categorical_450_preproce
  (Embedding)                                                    ssor[0][0]']                     
                                                                                                  
 oncotree_code_embed (Embedding  (None, 1, 2)        10          ['string_categorical_451_preproce
 )                                                               ssor[0][0]']                     
                                                                                                  
 pr_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_452_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 3-gene_classifier_subtype_embe  (None, 1, 2)        8           ['string_categorical_453_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 tf.cast_396 (TFOpLambda)       (None, 1)            0           ['float_normalized_317_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_397 (TFOpLambda)       (None, 1)            0           ['float_normalized_318_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_398 (TFOpLambda)       (None, 1)            0           ['float_normalized_319_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_399 (TFOpLambda)       (None, 1)            0           ['float_normalized_320_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_400 (TFOpLambda)       (None, 1)            0           ['float_normalized_321_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_401 (TFOpLambda)       (None, 1)            0           ['float_normalized_322_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_402 (TFOpLambda)       (None, 1)            0           ['float_normalized_323_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_403 (TFOpLambda)       (None, 1)            0           ['float_normalized_324_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_404 (TFOpLambda)       (None, 1)            0           ['float_normalized_314_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_405 (TFOpLambda)       (None, 1)            0           ['float_normalized_316_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_406 (TFOpLambda)       (None, 1)            0           ['float_normalized_315_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 flatten_638 (Flatten)          (None, 1)            0           ['type_of_breast_surgery_embed[0]
                                                                 [0]']                            
                                                                                                  
 flatten_639 (Flatten)          (None, 2)            0           ['cancer_type_detailed_embed[0][0
                                                                 ]']                              
                                                                                                  
 flatten_640 (Flatten)          (None, 1)            0           ['cellularity_embed[0][0]']      
                                                                                                  
 flatten_641 (Flatten)          (None, 2)            0           ['pam50_plus_claudin-low_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_642 (Flatten)          (None, 1)            0           ['er_status_measured_by_ihc_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_643 (Flatten)          (None, 1)            0           ['er_status_embed[0][0]']        
                                                                                                  
 flatten_644 (Flatten)          (None, 2)            0           ['her2_status_measured_by_snp6_em
                                                                 bed[0][0]']                      
                                                                                                  
 flatten_645 (Flatten)          (None, 1)            0           ['her2_status_embed[0][0]']      
                                                                                                  
 flatten_646 (Flatten)          (None, 2)            0           ['tumor_other_histologic_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_647 (Flatten)          (None, 1)            0           ['inferred_menopausal_state_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_648 (Flatten)          (None, 3)            0           ['integrative_cluster_embed[0][0]
                                                                 ']                               
                                                                                                  
 flatten_649 (Flatten)          (None, 1)            0           ['primary_tumor_laterality_embed[
                                                                 0][0]']                          
                                                                                                  
 flatten_650 (Flatten)          (None, 2)            0           ['oncotree_code_embed[0][0]']    
                                                                                                  
 flatten_651 (Flatten)          (None, 1)            0           ['pr_status_embed[0][0]']        
                                                                                                  
 flatten_652 (Flatten)          (None, 2)            0           ['3-gene_classifier_subtype_embed
                                                                 [0][0]']                         
                                                                                                  
 embed_concat (Concatenate)     (None, 34)           0           ['tf.cast_396[0][0]',            
                                                                  'tf.cast_397[0][0]',            
                                                                  'tf.cast_398[0][0]',            
                                                                  'tf.cast_399[0][0]',            
                                                                  'tf.cast_400[0][0]',            
                                                                  'tf.cast_401[0][0]',            
                                                                  'tf.cast_402[0][0]',            
                                                                  'tf.cast_403[0][0]',            
                                                                  'tf.cast_404[0][0]',            
                                                                  'tf.cast_405[0][0]',            
                                                                  'tf.cast_406[0][0]',            
                                                                  'flatten_638[0][0]',            
                                                                  'flatten_639[0][0]',            
                                                                  'flatten_640[0][0]',            
                                                                  'flatten_641[0][0]',            
                                                                  'flatten_642[0][0]',            
                                                                  'flatten_643[0][0]',            
                                                                  'flatten_644[0][0]',            
                                                                  'flatten_645[0][0]',            
                                                                  'flatten_646[0][0]',            
                                                                  'flatten_647[0][0]',            
                                                                  'flatten_648[0][0]',            
                                                                  'flatten_649[0][0]',            
                                                                  'flatten_650[0][0]',            
                                                                  'flatten_651[0][0]',            
                                                                  'flatten_652[0][0]']            
                                                                                                  
 her2_status_measured_by_snp6_X  (None, 1)           0           ['string_categorical_445_preproce
 _her2_status (HashedCrossing)                                   ssor[0][0]',                     
                                                                  'string_categorical_446_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 3-gene_classifier_subtype_X_in  (None, 1)           0           ['string_categorical_453_preproce
 tegrative_cluster_X_pam50_plus                                  ssor[0][0]',                     
 _claudin-low_subtype (HashedCr                                   'string_categorical_449_preproce
 ossing)                                                         ssor[0][0]',                     
                                                                  'string_categorical_442_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 er_status_X_er_status_measured  (None, 1)           0           ['string_categorical_444_preproce
 _by_ihc (HashedCrossing)                                        ssor[0][0]',                     
                                                                  'string_categorical_443_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 deep1 (Dense)                  (None, 34)           1190        ['embed_concat[0][0]']           
                                                                                                  
 her2_status_measured_by_snp6_X  (None, 1, 2)        16          ['her2_status_measured_by_snp6_X_
 _her2_status_embed (Embedding)                                  her2_status[0][0]']              
                                                                                                  
 3-gene_classifier_subtype_X_in  (None, 1, 17)       5236        ['3-gene_classifier_subtype_X_int
 tegrative_cluster_X_pam50_plus                                  egrative_cluster_X_pam50_plus_cla
 _claudin-low_subtype_embed (Em                                  udin-low_subtype[0][0]']         
 bedding)                                                                                         
                                                                                                  
 er_status_X_er_status_measured  (None, 1, 2)        8           ['er_status_X_er_status_measured_
 _by_ihc_embed (Embedding)                                       by_ihc[0][0]']                   
                                                                                                  
 deep2 (Dense)                  (None, 17)           595         ['deep1[0][0]']                  
                                                                                                  
 flatten_635 (Flatten)          (None, 2)            0           ['her2_status_measured_by_snp6_X_
                                                                 her2_status_embed[0][0]']        
                                                                                                  
 flatten_636 (Flatten)          (None, 17)           0           ['3-gene_classifier_subtype_X_int
                                                                 egrative_cluster_X_pam50_plus_cla
                                                                 udin-low_subtype_embed[0][0]']   
                                                                                                  
 flatten_637 (Flatten)          (None, 2)            0           ['er_status_X_er_status_measured_
                                                                 by_ihc_embed[0][0]']             
                                                                                                  
 deep3 (Dense)                  (None, 10)           180         ['deep2[0][0]']                  
                                                                                                  
 wide_concat (Concatenate)      (None, 21)           0           ['flatten_635[0][0]',            
                                                                  'flatten_636[0][0]',            
                                                                  'flatten_637[0][0]']            
                                                                                                  
 concat_deep_wide (Concatenate)  (None, 31)          0           ['deep3[0][0]',                  
                                                                  'wide_concat[0][0]']            
                                                                                                  
 combined (Dense)               (None, 1)            32          ['concat_deep_wide[0][0]']       
                                                                                                  
==================================================================================================
Total params: 7,404
Trainable params: 7,371
Non-trainable params: 33
__________________________________________________________________________________________________
Out[ ]:
No description has been provided for this image
In [ ]:
# train using the already processed features
history_2 = training_model_2.fit(
    ds_train, epochs=15, validation_data=ds_test, verbose=2
)
Epoch 1/15
19/19 - 4s - loss: 0.6784 - acc: 0.5670 - f1_m: 0.3342 - precision_m: 0.8136 - recall_m: 0.2542 - val_loss: 0.6557 - val_acc: 0.6667 - val_f1_m: 0.5209 - val_precision_m: 0.9167 - val_recall_m: 0.3925 - 4s/epoch - 231ms/step
Epoch 2/15
19/19 - 0s - loss: 0.6421 - acc: 0.7109 - f1_m: 0.5894 - precision_m: 0.8316 - recall_m: 0.4842 - val_loss: 0.6249 - val_acc: 0.7593 - val_f1_m: 0.7175 - val_precision_m: 0.8542 - val_recall_m: 0.6650 - 76ms/epoch - 4ms/step
Epoch 3/15
19/19 - 0s - loss: 0.6029 - acc: 0.7481 - f1_m: 0.7024 - precision_m: 0.8116 - recall_m: 0.6567 - val_loss: 0.5847 - val_acc: 0.7654 - val_f1_m: 0.7393 - val_precision_m: 0.8210 - val_recall_m: 0.7158 - 77ms/epoch - 4ms/step
Epoch 4/15
19/19 - 0s - loss: 0.5591 - acc: 0.7630 - f1_m: 0.7296 - precision_m: 0.7956 - recall_m: 0.7142 - val_loss: 0.5429 - val_acc: 0.7840 - val_f1_m: 0.7697 - val_precision_m: 0.8240 - val_recall_m: 0.7667 - 77ms/epoch - 4ms/step
Epoch 5/15
19/19 - 0s - loss: 0.5192 - acc: 0.7742 - f1_m: 0.7510 - precision_m: 0.7916 - recall_m: 0.7539 - val_loss: 0.5037 - val_acc: 0.8025 - val_f1_m: 0.7893 - val_precision_m: 0.8267 - val_recall_m: 0.7992 - 72ms/epoch - 4ms/step
Epoch 6/15
19/19 - 0s - loss: 0.4864 - acc: 0.7866 - f1_m: 0.7675 - precision_m: 0.7946 - recall_m: 0.7805 - val_loss: 0.4720 - val_acc: 0.8148 - val_f1_m: 0.8076 - val_precision_m: 0.8264 - val_recall_m: 0.8300 - 70ms/epoch - 4ms/step
Epoch 7/15
19/19 - 0s - loss: 0.4617 - acc: 0.7953 - f1_m: 0.7789 - precision_m: 0.7933 - recall_m: 0.8043 - val_loss: 0.4499 - val_acc: 0.8395 - val_f1_m: 0.8327 - val_precision_m: 0.8368 - val_recall_m: 0.8600 - 73ms/epoch - 4ms/step
Epoch 8/15
19/19 - 0s - loss: 0.4428 - acc: 0.8052 - f1_m: 0.7904 - precision_m: 0.7979 - recall_m: 0.8174 - val_loss: 0.4375 - val_acc: 0.8272 - val_f1_m: 0.8202 - val_precision_m: 0.8283 - val_recall_m: 0.8500 - 71ms/epoch - 4ms/step
Epoch 9/15
19/19 - 0s - loss: 0.4278 - acc: 0.8040 - f1_m: 0.7898 - precision_m: 0.7956 - recall_m: 0.8190 - val_loss: 0.4319 - val_acc: 0.8333 - val_f1_m: 0.8270 - val_precision_m: 0.8297 - val_recall_m: 0.8600 - 71ms/epoch - 4ms/step
Epoch 10/15
19/19 - 0s - loss: 0.4159 - acc: 0.8052 - f1_m: 0.7918 - precision_m: 0.7971 - recall_m: 0.8228 - val_loss: 0.4308 - val_acc: 0.8272 - val_f1_m: 0.8215 - val_precision_m: 0.8292 - val_recall_m: 0.8500 - 73ms/epoch - 4ms/step
Epoch 11/15
19/19 - 0s - loss: 0.4055 - acc: 0.8151 - f1_m: 0.8033 - precision_m: 0.8021 - recall_m: 0.8333 - val_loss: 0.4319 - val_acc: 0.8272 - val_f1_m: 0.8232 - val_precision_m: 0.8302 - val_recall_m: 0.8525 - 73ms/epoch - 4ms/step
Epoch 12/15
19/19 - 0s - loss: 0.3966 - acc: 0.8176 - f1_m: 0.8045 - precision_m: 0.8036 - recall_m: 0.8320 - val_loss: 0.4338 - val_acc: 0.8272 - val_f1_m: 0.8232 - val_precision_m: 0.8302 - val_recall_m: 0.8525 - 72ms/epoch - 4ms/step
Epoch 13/15
19/19 - 0s - loss: 0.3889 - acc: 0.8263 - f1_m: 0.8126 - precision_m: 0.8128 - recall_m: 0.8363 - val_loss: 0.4359 - val_acc: 0.8210 - val_f1_m: 0.8164 - val_precision_m: 0.8289 - val_recall_m: 0.8425 - 70ms/epoch - 4ms/step
Epoch 14/15
19/19 - 0s - loss: 0.3817 - acc: 0.8313 - f1_m: 0.8194 - precision_m: 0.8159 - recall_m: 0.8454 - val_loss: 0.4382 - val_acc: 0.8148 - val_f1_m: 0.8110 - val_precision_m: 0.8224 - val_recall_m: 0.8425 - 72ms/epoch - 4ms/step
Epoch 15/15
19/19 - 0s - loss: 0.3750 - acc: 0.8313 - f1_m: 0.8196 - precision_m: 0.8159 - recall_m: 0.8456 - val_loss: 0.4403 - val_acc: 0.8148 - val_f1_m: 0.8110 - val_precision_m: 0.8224 - val_recall_m: 0.8425 - 71ms/epoch - 4ms/step
In [ ]:
from matplotlib import pyplot as plt

%matplotlib inline

plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_2.history['f1_m'])

plt.ylabel('F1 Score %')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_2.history['val_f1_m'])
plt.title('Validation')

plt.subplot(2,2,3)
plt.plot(history_2.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')

plt.subplot(2,2,4)
plt.plot(history_2.history['val_loss'])
plt.xlabel('epochs')
Out[ ]:
Text(0.5, 0, 'epochs')
No description has been provided for this image
In [ ]:
#  Vizualize some metrics associated with this model
# Source: Modified from in-class lecture

# now lets see how well the model performed
yhat_proba_2 = training_model_2.predict(ds_test) # sigmoid output probabilities
# use squeeze to get rid of any extra dimensions 
yhat_2 = np.round(yhat_proba_2.squeeze()) # round to get binary class

conf_mat_2 = mt.confusion_matrix(y_test, yhat_2)

print(conf_mat_2)
print(mt.classification_report(y_test,yhat_2))

# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 20180309.  VitalBook file.
# Create pandas dataframe
conf_df_2 = pd.DataFrame(conf_mat_2, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())

# Create heatmap
sns.heatmap(conf_df_2, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 1s 3ms/step
[[65 15]
 [15 67]]
              precision    recall  f1-score   support

           0       0.81      0.81      0.81        80
           1       0.82      0.82      0.82        82

    accuracy                           0.81       162
   macro avg       0.81      0.81      0.81       162
weighted avg       0.81      0.81      0.81       162

No description has been provided for this image

With a little trial and error, I adjusted the epoch count down to 15 compared to 25 with model 1. On most runs I see overtraining start to occur after 9-15 epochs with my training loss on my validation data starting to trend upward. Due to this, I can say generally I'm seeing this model converge in fewer epochs. Also, I have an F1 score just slightly better than in model 1. My confusion matrix is roughly the same.

Model 3 of 3¶

Here, I'll go back to my original cross-categorical features (feature_space_1) and try changing my optimization method to see what effect that has on the results. For this network I'll use RMSProp instead of ADAM.

In [ ]:
# from keras.metrics import Precision, Recall
dict_inputs = feature_space_1.get_inputs() # need to use unprocessed features here, to gain access to each output

# we need to create separate lists for each branch
crossed_outputs = []

# for each crossed variable, make an embedding
for col in feature_space_1.crossers.keys():
    
    x = setup_embedding_from_crossing(feature_space_1, col)
    
    # save these outputs in list to concatenate later
    crossed_outputs.append(x)
    

# now concatenate the outputs and add a fully connected layer
wide_branch = Concatenate(name='wide_concat')(crossed_outputs)

# reset this input branch
all_deep_branch_outputs = []

# for each numeric variable, just add it in after embedding
for idx,col in enumerate(numeric_headers):
    x = feature_space_1.preprocessors[col].output
    x = tf.cast(x,float) # cast an integer as a float here
    all_deep_branch_outputs.append(x)
    
# for each categorical variable
for col in categorical_headers:
    
    # get the output tensor from ebedding layer
    x = setup_embedding_from_categorical(feature_space_1, col)
    
    # save these outputs in list to concatenate later
    all_deep_branch_outputs.append(x)


# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34,activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17,activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=10,activation='relu', name='deep3')(deep_branch)
    
# merge the deep and wide branch
final_branch = Concatenate(name='concat_deep_wide')([deep_branch, wide_branch])
final_branch = Dense(units=1,activation='sigmoid',
                     name='combined')(final_branch)

training_model_3 = keras.Model(inputs=dict_inputs, outputs=final_branch)

# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_3.compile(
    optimizer="RMSProp", loss="binary_crossentropy", metrics=['acc',f1_m,precision_m, recall_m]
)

training_model_3.summary()

plot_model(
    training_model_3, to_file='model.png', show_shapes=True, show_layer_names=True,
    rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_37"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 type_of_breast_surgery (InputL  [(None, 1)]         0           []                               
 ayer)                                                                                            
                                                                                                  
 cancer_type_detailed (InputLay  [(None, 1)]         0           []                               
 er)                                                                                              
                                                                                                  
 cellularity (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 pam50_plus_claudin-low_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 er_status_measured_by_ihc (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 er_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 her2_status_measured_by_snp6 (  [(None, 1)]         0           []                               
 InputLayer)                                                                                      
                                                                                                  
 her2_status (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 tumor_other_histologic_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 inferred_menopausal_state (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 integrative_cluster (InputLaye  [(None, 1)]         0           []                               
 r)                                                                                               
                                                                                                  
 primary_tumor_laterality (Inpu  [(None, 1)]         0           []                               
 tLayer)                                                                                          
                                                                                                  
 oncotree_code (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 pr_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 3-gene_classifier_subtype (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 age_at_diagnosis (InputLayer)  [(None, 1)]          0           []                               
                                                                                                  
 neoplasm_histologic_grade (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 lymph_nodes_examined_positive   [(None, 1)]         0           []                               
 (InputLayer)                                                                                     
                                                                                                  
 mutation_count (InputLayer)    [(None, 1)]          0           []                               
                                                                                                  
 nottingham_prognostic_index (I  [(None, 1)]         0           []                               
 nputLayer)                                                                                       
                                                                                                  
 overall_survival_months (Input  [(None, 1)]         0           []                               
 Layer)                                                                                           
                                                                                                  
 tumor_size (InputLayer)        [(None, 1)]          0           []                               
                                                                                                  
 tumor_stage (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 chemotherapy (InputLayer)      [(None, 1)]          0           []                               
                                                                                                  
 radio_therapy (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 hormone_therapy (InputLayer)   [(None, 1)]          0           []                               
                                                                                                  
 string_categorical_424_preproc  (None, 1)           0           ['type_of_breast_surgery[0][0]'] 
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_425_preproc  (None, 1)           0           ['cancer_type_detailed[0][0]']   
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_426_preproc  (None, 1)           0           ['cellularity[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_427_preproc  (None, 1)           0           ['pam50_plus_claudin-low_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_428_preproc  (None, 1)           0           ['er_status_measured_by_ihc[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_429_preproc  (None, 1)           0           ['er_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_430_preproc  (None, 1)           0           ['her2_status_measured_by_snp6[0]
 essor (StringLookup)                                            [0]']                            
                                                                                                  
 string_categorical_431_preproc  (None, 1)           0           ['her2_status[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_432_preproc  (None, 1)           0           ['tumor_other_histologic_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_433_preproc  (None, 1)           0           ['inferred_menopausal_state[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_434_preproc  (None, 1)           0           ['integrative_cluster[0][0]']    
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_435_preproc  (None, 1)           0           ['primary_tumor_laterality[0][0]'
 essor (StringLookup)                                            ]                                
                                                                                                  
 string_categorical_436_preproc  (None, 1)           0           ['oncotree_code[0][0]']          
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_437_preproc  (None, 1)           0           ['pr_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_438_preproc  (None, 1)           0           ['3-gene_classifier_subtype[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 float_normalized_306_preproces  (None, 1)           3           ['age_at_diagnosis[0][0]']       
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_307_preproces  (None, 1)           3           ['neoplasm_histologic_grade[0][0]
 sor (Normalization)                                             ']                               
                                                                                                  
 float_normalized_308_preproces  (None, 1)           3           ['lymph_nodes_examined_positive[0
 sor (Normalization)                                             ][0]']                           
                                                                                                  
 float_normalized_309_preproces  (None, 1)           3           ['mutation_count[0][0]']         
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_310_preproces  (None, 1)           3           ['nottingham_prognostic_index[0][
 sor (Normalization)                                             0]']                             
                                                                                                  
 float_normalized_311_preproces  (None, 1)           3           ['overall_survival_months[0][0]']
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_312_preproces  (None, 1)           3           ['tumor_size[0][0]']             
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_313_preproces  (None, 1)           3           ['tumor_stage[0][0]']            
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_303_preproces  (None, 1)           3           ['chemotherapy[0][0]']           
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_305_preproces  (None, 1)           3           ['radio_therapy[0][0]']          
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_304_preproces  (None, 1)           3           ['hormone_therapy[0][0]']        
 sor (Normalization)                                                                              
                                                                                                  
 type_of_breast_surgery_embed (  (None, 1, 1)        2           ['string_categorical_424_preproce
 Embedding)                                                      ssor[0][0]']                     
                                                                                                  
 cancer_type_detailed_embed (Em  (None, 1, 2)        10          ['string_categorical_425_preproce
 bedding)                                                        ssor[0][0]']                     
                                                                                                  
 cellularity_embed (Embedding)  (None, 1, 1)         3           ['string_categorical_426_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 pam50_plus_claudin-low_subtype  (None, 1, 2)        14          ['string_categorical_427_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 er_status_measured_by_ihc_embe  (None, 1, 1)        2           ['string_categorical_428_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 er_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_429_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 her2_status_measured_by_snp6_e  (None, 1, 2)        8           ['string_categorical_430_preproce
 mbed (Embedding)                                                ssor[0][0]']                     
                                                                                                  
 her2_status_embed (Embedding)  (None, 1, 1)         2           ['string_categorical_431_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 tumor_other_histologic_subtype  (None, 1, 2)        14          ['string_categorical_432_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 inferred_menopausal_state_embe  (None, 1, 1)        2           ['string_categorical_433_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 integrative_cluster_embed (Emb  (None, 1, 3)        33          ['string_categorical_434_preproce
 edding)                                                         ssor[0][0]']                     
                                                                                                  
 primary_tumor_laterality_embed  (None, 1, 1)        2           ['string_categorical_435_preproce
  (Embedding)                                                    ssor[0][0]']                     
                                                                                                  
 oncotree_code_embed (Embedding  (None, 1, 2)        10          ['string_categorical_436_preproce
 )                                                               ssor[0][0]']                     
                                                                                                  
 pr_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_437_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 3-gene_classifier_subtype_embe  (None, 1, 2)        8           ['string_categorical_438_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 tf.cast_407 (TFOpLambda)       (None, 1)            0           ['float_normalized_306_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_408 (TFOpLambda)       (None, 1)            0           ['float_normalized_307_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_409 (TFOpLambda)       (None, 1)            0           ['float_normalized_308_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_410 (TFOpLambda)       (None, 1)            0           ['float_normalized_309_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_411 (TFOpLambda)       (None, 1)            0           ['float_normalized_310_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_412 (TFOpLambda)       (None, 1)            0           ['float_normalized_311_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_413 (TFOpLambda)       (None, 1)            0           ['float_normalized_312_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_414 (TFOpLambda)       (None, 1)            0           ['float_normalized_313_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_415 (TFOpLambda)       (None, 1)            0           ['float_normalized_303_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_416 (TFOpLambda)       (None, 1)            0           ['float_normalized_305_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_417 (TFOpLambda)       (None, 1)            0           ['float_normalized_304_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 flatten_656 (Flatten)          (None, 1)            0           ['type_of_breast_surgery_embed[0]
                                                                 [0]']                            
                                                                                                  
 flatten_657 (Flatten)          (None, 2)            0           ['cancer_type_detailed_embed[0][0
                                                                 ]']                              
                                                                                                  
 flatten_658 (Flatten)          (None, 1)            0           ['cellularity_embed[0][0]']      
                                                                                                  
 flatten_659 (Flatten)          (None, 2)            0           ['pam50_plus_claudin-low_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_660 (Flatten)          (None, 1)            0           ['er_status_measured_by_ihc_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_661 (Flatten)          (None, 1)            0           ['er_status_embed[0][0]']        
                                                                                                  
 flatten_662 (Flatten)          (None, 2)            0           ['her2_status_measured_by_snp6_em
                                                                 bed[0][0]']                      
                                                                                                  
 flatten_663 (Flatten)          (None, 1)            0           ['her2_status_embed[0][0]']      
                                                                                                  
 flatten_664 (Flatten)          (None, 2)            0           ['tumor_other_histologic_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_665 (Flatten)          (None, 1)            0           ['inferred_menopausal_state_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_666 (Flatten)          (None, 3)            0           ['integrative_cluster_embed[0][0]
                                                                 ']                               
                                                                                                  
 flatten_667 (Flatten)          (None, 1)            0           ['primary_tumor_laterality_embed[
                                                                 0][0]']                          
                                                                                                  
 flatten_668 (Flatten)          (None, 2)            0           ['oncotree_code_embed[0][0]']    
                                                                                                  
 flatten_669 (Flatten)          (None, 1)            0           ['pr_status_embed[0][0]']        
                                                                                                  
 flatten_670 (Flatten)          (None, 2)            0           ['3-gene_classifier_subtype_embed
                                                                 [0][0]']                         
                                                                                                  
 embed_concat (Concatenate)     (None, 34)           0           ['tf.cast_407[0][0]',            
                                                                  'tf.cast_408[0][0]',            
                                                                  'tf.cast_409[0][0]',            
                                                                  'tf.cast_410[0][0]',            
                                                                  'tf.cast_411[0][0]',            
                                                                  'tf.cast_412[0][0]',            
                                                                  'tf.cast_413[0][0]',            
                                                                  'tf.cast_414[0][0]',            
                                                                  'tf.cast_415[0][0]',            
                                                                  'tf.cast_416[0][0]',            
                                                                  'tf.cast_417[0][0]',            
                                                                  'flatten_656[0][0]',            
                                                                  'flatten_657[0][0]',            
                                                                  'flatten_658[0][0]',            
                                                                  'flatten_659[0][0]',            
                                                                  'flatten_660[0][0]',            
                                                                  'flatten_661[0][0]',            
                                                                  'flatten_662[0][0]',            
                                                                  'flatten_663[0][0]',            
                                                                  'flatten_664[0][0]',            
                                                                  'flatten_665[0][0]',            
                                                                  'flatten_666[0][0]',            
                                                                  'flatten_667[0][0]',            
                                                                  'flatten_668[0][0]',            
                                                                  'flatten_669[0][0]',            
                                                                  'flatten_670[0][0]']            
                                                                                                  
 her2_status_measured_by_snp6_X  (None, 1)           0           ['string_categorical_430_preproce
 _her2_status (HashedCrossing)                                   ssor[0][0]',                     
                                                                  'string_categorical_431_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 3-gene_classifier_subtype_X_in  (None, 1)           0           ['string_categorical_438_preproce
 tegrative_cluster (HashedCross                                  ssor[0][0]',                     
 ing)                                                             'string_categorical_434_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 er_status_X_er_status_measured  (None, 1)           0           ['string_categorical_429_preproce
 _by_ihc (HashedCrossing)                                        ssor[0][0]',                     
                                                                  'string_categorical_428_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 deep1 (Dense)                  (None, 34)           1190        ['embed_concat[0][0]']           
                                                                                                  
 her2_status_measured_by_snp6_X  (None, 1, 2)        16          ['her2_status_measured_by_snp6_X_
 _her2_status_embed (Embedding)                                  her2_status[0][0]']              
                                                                                                  
 3-gene_classifier_subtype_X_in  (None, 1, 6)        264         ['3-gene_classifier_subtype_X_int
 tegrative_cluster_embed (Embed                                  egrative_cluster[0][0]']         
 ding)                                                                                            
                                                                                                  
 er_status_X_er_status_measured  (None, 1, 2)        8           ['er_status_X_er_status_measured_
 _by_ihc_embed (Embedding)                                       by_ihc[0][0]']                   
                                                                                                  
 deep2 (Dense)                  (None, 17)           595         ['deep1[0][0]']                  
                                                                                                  
 flatten_653 (Flatten)          (None, 2)            0           ['her2_status_measured_by_snp6_X_
                                                                 her2_status_embed[0][0]']        
                                                                                                  
 flatten_654 (Flatten)          (None, 6)            0           ['3-gene_classifier_subtype_X_int
                                                                 egrative_cluster_embed[0][0]']   
                                                                                                  
 flatten_655 (Flatten)          (None, 2)            0           ['er_status_X_er_status_measured_
                                                                 by_ihc_embed[0][0]']             
                                                                                                  
 deep3 (Dense)                  (None, 10)           180         ['deep2[0][0]']                  
                                                                                                  
 wide_concat (Concatenate)      (None, 10)           0           ['flatten_653[0][0]',            
                                                                  'flatten_654[0][0]',            
                                                                  'flatten_655[0][0]']            
                                                                                                  
 concat_deep_wide (Concatenate)  (None, 20)          0           ['deep3[0][0]',                  
                                                                  'wide_concat[0][0]']            
                                                                                                  
 combined (Dense)               (None, 1)            21          ['concat_deep_wide[0][0]']       
                                                                                                  
==================================================================================================
Total params: 2,421
Trainable params: 2,388
Non-trainable params: 33
__________________________________________________________________________________________________
Out[ ]:
No description has been provided for this image
In [ ]:
# train using the already processed features
history_3 = training_model_3.fit(
    ds_train, epochs=20, validation_data=ds_test, verbose=2
)
Epoch 1/20
19/19 - 4s - loss: 0.6477 - acc: 0.6365 - f1_m: 0.4069 - precision_m: 0.8202 - recall_m: 0.2962 - val_loss: 0.6280 - val_acc: 0.7222 - val_f1_m: 0.6945 - val_precision_m: 0.7586 - val_recall_m: 0.6725 - 4s/epoch - 194ms/step
Epoch 2/20
19/19 - 0s - loss: 0.5937 - acc: 0.7320 - f1_m: 0.6863 - precision_m: 0.7916 - recall_m: 0.6527 - val_loss: 0.5864 - val_acc: 0.7346 - val_f1_m: 0.7364 - val_precision_m: 0.7493 - val_recall_m: 0.7842 - 75ms/epoch - 4ms/step
Epoch 3/20
19/19 - 0s - loss: 0.5551 - acc: 0.7680 - f1_m: 0.7347 - precision_m: 0.8019 - recall_m: 0.7231 - val_loss: 0.5488 - val_acc: 0.7407 - val_f1_m: 0.7507 - val_precision_m: 0.7373 - val_recall_m: 0.8142 - 71ms/epoch - 4ms/step
Epoch 4/20
19/19 - 0s - loss: 0.5223 - acc: 0.7742 - f1_m: 0.7449 - precision_m: 0.7896 - recall_m: 0.7459 - val_loss: 0.5182 - val_acc: 0.7593 - val_f1_m: 0.7727 - val_precision_m: 0.7463 - val_recall_m: 0.8550 - 71ms/epoch - 4ms/step
Epoch 5/20
19/19 - 0s - loss: 0.4965 - acc: 0.7841 - f1_m: 0.7599 - precision_m: 0.7909 - recall_m: 0.7688 - val_loss: 0.4942 - val_acc: 0.7716 - val_f1_m: 0.7834 - val_precision_m: 0.7577 - val_recall_m: 0.8650 - 72ms/epoch - 4ms/step
Epoch 6/20
19/19 - 0s - loss: 0.4768 - acc: 0.7953 - f1_m: 0.7732 - precision_m: 0.7942 - recall_m: 0.7877 - val_loss: 0.4766 - val_acc: 0.7778 - val_f1_m: 0.7873 - val_precision_m: 0.7669 - val_recall_m: 0.8650 - 76ms/epoch - 4ms/step
Epoch 7/20
19/19 - 0s - loss: 0.4616 - acc: 0.8002 - f1_m: 0.7782 - precision_m: 0.7968 - recall_m: 0.7952 - val_loss: 0.4642 - val_acc: 0.7901 - val_f1_m: 0.7981 - val_precision_m: 0.7733 - val_recall_m: 0.8750 - 71ms/epoch - 4ms/step
Epoch 8/20
19/19 - 0s - loss: 0.4509 - acc: 0.8027 - f1_m: 0.7833 - precision_m: 0.7959 - recall_m: 0.8088 - val_loss: 0.4563 - val_acc: 0.7840 - val_f1_m: 0.7913 - val_precision_m: 0.7677 - val_recall_m: 0.8650 - 71ms/epoch - 4ms/step
Epoch 9/20
19/19 - 0s - loss: 0.4425 - acc: 0.8065 - f1_m: 0.7859 - precision_m: 0.7980 - recall_m: 0.8089 - val_loss: 0.4517 - val_acc: 0.7840 - val_f1_m: 0.7913 - val_precision_m: 0.7677 - val_recall_m: 0.8650 - 72ms/epoch - 4ms/step
Epoch 10/20
19/19 - 0s - loss: 0.4360 - acc: 0.8052 - f1_m: 0.7848 - precision_m: 0.7956 - recall_m: 0.8089 - val_loss: 0.4487 - val_acc: 0.7901 - val_f1_m: 0.7958 - val_precision_m: 0.7727 - val_recall_m: 0.8650 - 70ms/epoch - 4ms/step
Epoch 11/20
19/19 - 0s - loss: 0.4301 - acc: 0.8089 - f1_m: 0.7915 - precision_m: 0.7989 - recall_m: 0.8164 - val_loss: 0.4471 - val_acc: 0.7963 - val_f1_m: 0.8000 - val_precision_m: 0.7825 - val_recall_m: 0.8650 - 68ms/epoch - 4ms/step
Epoch 12/20
19/19 - 0s - loss: 0.4254 - acc: 0.8089 - f1_m: 0.7922 - precision_m: 0.8020 - recall_m: 0.8164 - val_loss: 0.4464 - val_acc: 0.7963 - val_f1_m: 0.8000 - val_precision_m: 0.7825 - val_recall_m: 0.8650 - 69ms/epoch - 4ms/step
Epoch 13/20
19/19 - 0s - loss: 0.4208 - acc: 0.8127 - f1_m: 0.7960 - precision_m: 0.8053 - recall_m: 0.8192 - val_loss: 0.4459 - val_acc: 0.7963 - val_f1_m: 0.8000 - val_precision_m: 0.7825 - val_recall_m: 0.8650 - 67ms/epoch - 4ms/step
Epoch 14/20
19/19 - 0s - loss: 0.4164 - acc: 0.8176 - f1_m: 0.8040 - precision_m: 0.8068 - recall_m: 0.8320 - val_loss: 0.4459 - val_acc: 0.7963 - val_f1_m: 0.8000 - val_precision_m: 0.7825 - val_recall_m: 0.8650 - 69ms/epoch - 4ms/step
Epoch 15/20
19/19 - 0s - loss: 0.4124 - acc: 0.8238 - f1_m: 0.8100 - precision_m: 0.8134 - recall_m: 0.8344 - val_loss: 0.4459 - val_acc: 0.7901 - val_f1_m: 0.7930 - val_precision_m: 0.7772 - val_recall_m: 0.8550 - 71ms/epoch - 4ms/step
Epoch 16/20
19/19 - 0s - loss: 0.4085 - acc: 0.8263 - f1_m: 0.8136 - precision_m: 0.8211 - recall_m: 0.8357 - val_loss: 0.4460 - val_acc: 0.7963 - val_f1_m: 0.7973 - val_precision_m: 0.7884 - val_recall_m: 0.8550 - 70ms/epoch - 4ms/step
Epoch 17/20
19/19 - 0s - loss: 0.4047 - acc: 0.8313 - f1_m: 0.8171 - precision_m: 0.8262 - recall_m: 0.8354 - val_loss: 0.4462 - val_acc: 0.8025 - val_f1_m: 0.8024 - val_precision_m: 0.7943 - val_recall_m: 0.8550 - 72ms/epoch - 4ms/step
Epoch 18/20
19/19 - 0s - loss: 0.4011 - acc: 0.8300 - f1_m: 0.8162 - precision_m: 0.8244 - recall_m: 0.8354 - val_loss: 0.4465 - val_acc: 0.8025 - val_f1_m: 0.8024 - val_precision_m: 0.7943 - val_recall_m: 0.8550 - 83ms/epoch - 4ms/step
Epoch 19/20
19/19 - 0s - loss: 0.3975 - acc: 0.8325 - f1_m: 0.8198 - precision_m: 0.8237 - recall_m: 0.8427 - val_loss: 0.4467 - val_acc: 0.8025 - val_f1_m: 0.8024 - val_precision_m: 0.7943 - val_recall_m: 0.8550 - 81ms/epoch - 4ms/step
Epoch 20/20
19/19 - 0s - loss: 0.3938 - acc: 0.8362 - f1_m: 0.8230 - precision_m: 0.8286 - recall_m: 0.8427 - val_loss: 0.4471 - val_acc: 0.8025 - val_f1_m: 0.8024 - val_precision_m: 0.7943 - val_recall_m: 0.8550 - 73ms/epoch - 4ms/step
In [ ]:
from matplotlib import pyplot as plt

%matplotlib inline

plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_3.history['f1_m'])

plt.ylabel('F1 Score %')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_3.history['val_f1_m'])
plt.title('Validation')

plt.subplot(2,2,3)
plt.plot(history_3.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')

plt.subplot(2,2,4)
plt.plot(history_3.history['val_loss'])
plt.xlabel('epochs')
Out[ ]:
Text(0.5, 0, 'epochs')
No description has been provided for this image
In [ ]:
#  Vizualize some metrics associated with this model

# Source: Modified from in-class lecture
# Use the sklearn metrics here, if you want to
# from sklearn import metrics as mt

# y_test = tf.concat([y for x, y in ds_test], axis=0)
# y_test = y_test.numpy()

# now lets see how well the model performed
yhat_proba_3 = training_model_3.predict(ds_test) # sigmoid output probabilities
# use squeeze to get rid of any extra dimensions 
yhat_3 = np.round(yhat_proba_3.squeeze()) # round to get binary class

conf_mat_3 = mt.confusion_matrix(y_test, yhat_3)

print(conf_mat_3)
print(mt.classification_report(y_test,yhat_3))

# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 20180309.  VitalBook file.
# Create pandas dataframe
conf_df_3 = pd.DataFrame(conf_mat_3, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())

# Create heatmap
sns.heatmap(conf_df_3, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 1s 2ms/step
[[62 18]
 [14 68]]
              precision    recall  f1-score   support

           0       0.82      0.78      0.79        80
           1       0.79      0.83      0.81        82

    accuracy                           0.80       162
   macro avg       0.80      0.80      0.80       162
weighted avg       0.80      0.80      0.80       162

No description has been provided for this image

This third model performed very similarly to the first which I found interesting. I have a very similar F1 score and but my confusion matrix is a little worse. I can't draw any real conclusions about the change in optimizer. After running all of these models several times I can say that the results jumped around. I had F1 scores from 80 to 84 across all three models which switched from model to model. Any of these would be viable candidates to carry forward to the next evaluation.

Investigating Generalization Performance¶

For this portion of the lab I will use model 2 of 3 from above, which specifically means I'll be using training_model_2 and feature_space_2. Any model would have likely been appropriate and seemed comparable, however during multiple runs, that model appeared to converge the fastest on the validation data.

Per the instructions, I consider model 2 of 3 above to be one of the two required for this portion of the rubric. Therefore I now need to alter model 2 to see if different performance can be achieved. To do this I'll be adding a layer to the deep portion of the network to step it down a little closer to my final binary output layer. I also changed the number of neurons in the last two layers. Finally I altered the number of epochs after a few test runs showed my model taking longer to converge on validation.

In [ ]:
# from keras.metrics import Precision, Recall
dict_inputs = feature_space_2.get_inputs() # need to use unprocessed features here, to gain access to each output

# we need to create separate lists for each branch
crossed_outputs = []

# for each crossed variable, make an embedding
for col in feature_space_2.crossers.keys():
    
    x = setup_embedding_from_crossing(feature_space_2, col)
    
    # save these outputs in list to concatenate later
    crossed_outputs.append(x)
    

# now concatenate the outputs and add a fully connected layer
wide_branch = Concatenate(name='wide_concat')(crossed_outputs)

# reset this input branch
all_deep_branch_outputs = []

# for each numeric variable, just add it in after embedding
for idx,col in enumerate(numeric_headers):
    x = feature_space_2.preprocessors[col].output
    x = tf.cast(x,float) # cast an integer as a float here
    all_deep_branch_outputs.append(x)
    
# for each categorical variable
for col in categorical_headers:
    
    # get the output tensor from ebedding layer
    x = setup_embedding_from_categorical(feature_space_2, col)
    
    # save these outputs in list to concatenate later
    all_deep_branch_outputs.append(x)


# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34,activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17,activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=8,activation='relu', name='deep3')(deep_branch) # Changed from 10 to 8 neurons
deep_branch = Dense(units=4,activation='relu', name='deep4')(deep_branch) # This is my new layer

    
# merge the deep and wide branch
final_branch = Concatenate(name='concat_deep_wide')([deep_branch, wide_branch])
final_branch = Dense(units=1,activation='sigmoid',
                     name='combined')(final_branch)

training_model_4 = keras.Model(inputs=dict_inputs, outputs=final_branch)

# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_4.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=['acc',f1_m,precision_m, recall_m]
)

training_model_4.summary()

plot_model(
    training_model_4, to_file='model.png', show_shapes=True, show_layer_names=True,
    rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_38"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 type_of_breast_surgery (InputL  [(None, 1)]         0           []                               
 ayer)                                                                                            
                                                                                                  
 cancer_type_detailed (InputLay  [(None, 1)]         0           []                               
 er)                                                                                              
                                                                                                  
 cellularity (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 pam50_plus_claudin-low_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 er_status_measured_by_ihc (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 er_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 her2_status_measured_by_snp6 (  [(None, 1)]         0           []                               
 InputLayer)                                                                                      
                                                                                                  
 her2_status (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 tumor_other_histologic_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 inferred_menopausal_state (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 integrative_cluster (InputLaye  [(None, 1)]         0           []                               
 r)                                                                                               
                                                                                                  
 primary_tumor_laterality (Inpu  [(None, 1)]         0           []                               
 tLayer)                                                                                          
                                                                                                  
 oncotree_code (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 pr_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 3-gene_classifier_subtype (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 age_at_diagnosis (InputLayer)  [(None, 1)]          0           []                               
                                                                                                  
 neoplasm_histologic_grade (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 lymph_nodes_examined_positive   [(None, 1)]         0           []                               
 (InputLayer)                                                                                     
                                                                                                  
 mutation_count (InputLayer)    [(None, 1)]          0           []                               
                                                                                                  
 nottingham_prognostic_index (I  [(None, 1)]         0           []                               
 nputLayer)                                                                                       
                                                                                                  
 overall_survival_months (Input  [(None, 1)]         0           []                               
 Layer)                                                                                           
                                                                                                  
 tumor_size (InputLayer)        [(None, 1)]          0           []                               
                                                                                                  
 tumor_stage (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 chemotherapy (InputLayer)      [(None, 1)]          0           []                               
                                                                                                  
 radio_therapy (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 hormone_therapy (InputLayer)   [(None, 1)]          0           []                               
                                                                                                  
 string_categorical_439_preproc  (None, 1)           0           ['type_of_breast_surgery[0][0]'] 
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_440_preproc  (None, 1)           0           ['cancer_type_detailed[0][0]']   
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_441_preproc  (None, 1)           0           ['cellularity[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_442_preproc  (None, 1)           0           ['pam50_plus_claudin-low_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_443_preproc  (None, 1)           0           ['er_status_measured_by_ihc[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_444_preproc  (None, 1)           0           ['er_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_445_preproc  (None, 1)           0           ['her2_status_measured_by_snp6[0]
 essor (StringLookup)                                            [0]']                            
                                                                                                  
 string_categorical_446_preproc  (None, 1)           0           ['her2_status[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_447_preproc  (None, 1)           0           ['tumor_other_histologic_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_448_preproc  (None, 1)           0           ['inferred_menopausal_state[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_449_preproc  (None, 1)           0           ['integrative_cluster[0][0]']    
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_450_preproc  (None, 1)           0           ['primary_tumor_laterality[0][0]'
 essor (StringLookup)                                            ]                                
                                                                                                  
 string_categorical_451_preproc  (None, 1)           0           ['oncotree_code[0][0]']          
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_452_preproc  (None, 1)           0           ['pr_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_453_preproc  (None, 1)           0           ['3-gene_classifier_subtype[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 float_normalized_317_preproces  (None, 1)           3           ['age_at_diagnosis[0][0]']       
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_318_preproces  (None, 1)           3           ['neoplasm_histologic_grade[0][0]
 sor (Normalization)                                             ']                               
                                                                                                  
 float_normalized_319_preproces  (None, 1)           3           ['lymph_nodes_examined_positive[0
 sor (Normalization)                                             ][0]']                           
                                                                                                  
 float_normalized_320_preproces  (None, 1)           3           ['mutation_count[0][0]']         
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_321_preproces  (None, 1)           3           ['nottingham_prognostic_index[0][
 sor (Normalization)                                             0]']                             
                                                                                                  
 float_normalized_322_preproces  (None, 1)           3           ['overall_survival_months[0][0]']
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_323_preproces  (None, 1)           3           ['tumor_size[0][0]']             
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_324_preproces  (None, 1)           3           ['tumor_stage[0][0]']            
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_314_preproces  (None, 1)           3           ['chemotherapy[0][0]']           
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_316_preproces  (None, 1)           3           ['radio_therapy[0][0]']          
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_315_preproces  (None, 1)           3           ['hormone_therapy[0][0]']        
 sor (Normalization)                                                                              
                                                                                                  
 type_of_breast_surgery_embed (  (None, 1, 1)        2           ['string_categorical_439_preproce
 Embedding)                                                      ssor[0][0]']                     
                                                                                                  
 cancer_type_detailed_embed (Em  (None, 1, 2)        10          ['string_categorical_440_preproce
 bedding)                                                        ssor[0][0]']                     
                                                                                                  
 cellularity_embed (Embedding)  (None, 1, 1)         3           ['string_categorical_441_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 pam50_plus_claudin-low_subtype  (None, 1, 2)        14          ['string_categorical_442_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 er_status_measured_by_ihc_embe  (None, 1, 1)        2           ['string_categorical_443_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 er_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_444_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 her2_status_measured_by_snp6_e  (None, 1, 2)        8           ['string_categorical_445_preproce
 mbed (Embedding)                                                ssor[0][0]']                     
                                                                                                  
 her2_status_embed (Embedding)  (None, 1, 1)         2           ['string_categorical_446_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 tumor_other_histologic_subtype  (None, 1, 2)        14          ['string_categorical_447_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 inferred_menopausal_state_embe  (None, 1, 1)        2           ['string_categorical_448_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 integrative_cluster_embed (Emb  (None, 1, 3)        33          ['string_categorical_449_preproce
 edding)                                                         ssor[0][0]']                     
                                                                                                  
 primary_tumor_laterality_embed  (None, 1, 1)        2           ['string_categorical_450_preproce
  (Embedding)                                                    ssor[0][0]']                     
                                                                                                  
 oncotree_code_embed (Embedding  (None, 1, 2)        10          ['string_categorical_451_preproce
 )                                                               ssor[0][0]']                     
                                                                                                  
 pr_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_452_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 3-gene_classifier_subtype_embe  (None, 1, 2)        8           ['string_categorical_453_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 tf.cast_418 (TFOpLambda)       (None, 1)            0           ['float_normalized_317_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_419 (TFOpLambda)       (None, 1)            0           ['float_normalized_318_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_420 (TFOpLambda)       (None, 1)            0           ['float_normalized_319_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_421 (TFOpLambda)       (None, 1)            0           ['float_normalized_320_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_422 (TFOpLambda)       (None, 1)            0           ['float_normalized_321_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_423 (TFOpLambda)       (None, 1)            0           ['float_normalized_322_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_424 (TFOpLambda)       (None, 1)            0           ['float_normalized_323_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_425 (TFOpLambda)       (None, 1)            0           ['float_normalized_324_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_426 (TFOpLambda)       (None, 1)            0           ['float_normalized_314_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_427 (TFOpLambda)       (None, 1)            0           ['float_normalized_316_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_428 (TFOpLambda)       (None, 1)            0           ['float_normalized_315_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 flatten_674 (Flatten)          (None, 1)            0           ['type_of_breast_surgery_embed[0]
                                                                 [0]']                            
                                                                                                  
 flatten_675 (Flatten)          (None, 2)            0           ['cancer_type_detailed_embed[0][0
                                                                 ]']                              
                                                                                                  
 flatten_676 (Flatten)          (None, 1)            0           ['cellularity_embed[0][0]']      
                                                                                                  
 flatten_677 (Flatten)          (None, 2)            0           ['pam50_plus_claudin-low_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_678 (Flatten)          (None, 1)            0           ['er_status_measured_by_ihc_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_679 (Flatten)          (None, 1)            0           ['er_status_embed[0][0]']        
                                                                                                  
 flatten_680 (Flatten)          (None, 2)            0           ['her2_status_measured_by_snp6_em
                                                                 bed[0][0]']                      
                                                                                                  
 flatten_681 (Flatten)          (None, 1)            0           ['her2_status_embed[0][0]']      
                                                                                                  
 flatten_682 (Flatten)          (None, 2)            0           ['tumor_other_histologic_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_683 (Flatten)          (None, 1)            0           ['inferred_menopausal_state_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_684 (Flatten)          (None, 3)            0           ['integrative_cluster_embed[0][0]
                                                                 ']                               
                                                                                                  
 flatten_685 (Flatten)          (None, 1)            0           ['primary_tumor_laterality_embed[
                                                                 0][0]']                          
                                                                                                  
 flatten_686 (Flatten)          (None, 2)            0           ['oncotree_code_embed[0][0]']    
                                                                                                  
 flatten_687 (Flatten)          (None, 1)            0           ['pr_status_embed[0][0]']        
                                                                                                  
 flatten_688 (Flatten)          (None, 2)            0           ['3-gene_classifier_subtype_embed
                                                                 [0][0]']                         
                                                                                                  
 embed_concat (Concatenate)     (None, 34)           0           ['tf.cast_418[0][0]',            
                                                                  'tf.cast_419[0][0]',            
                                                                  'tf.cast_420[0][0]',            
                                                                  'tf.cast_421[0][0]',            
                                                                  'tf.cast_422[0][0]',            
                                                                  'tf.cast_423[0][0]',            
                                                                  'tf.cast_424[0][0]',            
                                                                  'tf.cast_425[0][0]',            
                                                                  'tf.cast_426[0][0]',            
                                                                  'tf.cast_427[0][0]',            
                                                                  'tf.cast_428[0][0]',            
                                                                  'flatten_674[0][0]',            
                                                                  'flatten_675[0][0]',            
                                                                  'flatten_676[0][0]',            
                                                                  'flatten_677[0][0]',            
                                                                  'flatten_678[0][0]',            
                                                                  'flatten_679[0][0]',            
                                                                  'flatten_680[0][0]',            
                                                                  'flatten_681[0][0]',            
                                                                  'flatten_682[0][0]',            
                                                                  'flatten_683[0][0]',            
                                                                  'flatten_684[0][0]',            
                                                                  'flatten_685[0][0]',            
                                                                  'flatten_686[0][0]',            
                                                                  'flatten_687[0][0]',            
                                                                  'flatten_688[0][0]']            
                                                                                                  
 deep1 (Dense)                  (None, 34)           1190        ['embed_concat[0][0]']           
                                                                                                  
 her2_status_measured_by_snp6_X  (None, 1)           0           ['string_categorical_445_preproce
 _her2_status (HashedCrossing)                                   ssor[0][0]',                     
                                                                  'string_categorical_446_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 3-gene_classifier_subtype_X_in  (None, 1)           0           ['string_categorical_453_preproce
 tegrative_cluster_X_pam50_plus                                  ssor[0][0]',                     
 _claudin-low_subtype (HashedCr                                   'string_categorical_449_preproce
 ossing)                                                         ssor[0][0]',                     
                                                                  'string_categorical_442_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 er_status_X_er_status_measured  (None, 1)           0           ['string_categorical_444_preproce
 _by_ihc (HashedCrossing)                                        ssor[0][0]',                     
                                                                  'string_categorical_443_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 deep2 (Dense)                  (None, 17)           595         ['deep1[0][0]']                  
                                                                                                  
 her2_status_measured_by_snp6_X  (None, 1, 2)        16          ['her2_status_measured_by_snp6_X_
 _her2_status_embed (Embedding)                                  her2_status[0][0]']              
                                                                                                  
 3-gene_classifier_subtype_X_in  (None, 1, 17)       5236        ['3-gene_classifier_subtype_X_int
 tegrative_cluster_X_pam50_plus                                  egrative_cluster_X_pam50_plus_cla
 _claudin-low_subtype_embed (Em                                  udin-low_subtype[0][0]']         
 bedding)                                                                                         
                                                                                                  
 er_status_X_er_status_measured  (None, 1, 2)        8           ['er_status_X_er_status_measured_
 _by_ihc_embed (Embedding)                                       by_ihc[0][0]']                   
                                                                                                  
 deep3 (Dense)                  (None, 8)            144         ['deep2[0][0]']                  
                                                                                                  
 flatten_671 (Flatten)          (None, 2)            0           ['her2_status_measured_by_snp6_X_
                                                                 her2_status_embed[0][0]']        
                                                                                                  
 flatten_672 (Flatten)          (None, 17)           0           ['3-gene_classifier_subtype_X_int
                                                                 egrative_cluster_X_pam50_plus_cla
                                                                 udin-low_subtype_embed[0][0]']   
                                                                                                  
 flatten_673 (Flatten)          (None, 2)            0           ['er_status_X_er_status_measured_
                                                                 by_ihc_embed[0][0]']             
                                                                                                  
 deep4 (Dense)                  (None, 4)            36          ['deep3[0][0]']                  
                                                                                                  
 wide_concat (Concatenate)      (None, 21)           0           ['flatten_671[0][0]',            
                                                                  'flatten_672[0][0]',            
                                                                  'flatten_673[0][0]']            
                                                                                                  
 concat_deep_wide (Concatenate)  (None, 25)          0           ['deep4[0][0]',                  
                                                                  'wide_concat[0][0]']            
                                                                                                  
 combined (Dense)               (None, 1)            26          ['concat_deep_wide[0][0]']       
                                                                                                  
==================================================================================================
Total params: 7,398
Trainable params: 7,365
Non-trainable params: 33
__________________________________________________________________________________________________
Out[ ]:
No description has been provided for this image
In [ ]:
# train using the already processed features
history_4 = training_model_4.fit(
    ds_train, epochs=25, validation_data=ds_test, verbose=2
) # changed number of epochs
Epoch 1/25
19/19 - 4s - loss: 0.6965 - acc: 0.5261 - f1_m: 0.1033 - precision_m: 0.5263 - recall_m: 0.0598 - val_loss: 0.6834 - val_acc: 0.5370 - val_f1_m: 0.1642 - val_precision_m: 0.7500 - val_recall_m: 0.0942 - 4s/epoch - 236ms/step
Epoch 2/25
19/19 - 0s - loss: 0.6697 - acc: 0.6154 - f1_m: 0.3665 - precision_m: 0.8351 - recall_m: 0.2450 - val_loss: 0.6695 - val_acc: 0.6296 - val_f1_m: 0.4410 - val_precision_m: 0.9167 - val_recall_m: 0.2992 - 75ms/epoch - 4ms/step
Epoch 3/25
19/19 - 0s - loss: 0.6517 - acc: 0.6675 - f1_m: 0.5365 - precision_m: 0.8302 - recall_m: 0.4060 - val_loss: 0.6542 - val_acc: 0.6728 - val_f1_m: 0.5523 - val_precision_m: 0.8750 - val_recall_m: 0.4200 - 76ms/epoch - 4ms/step
Epoch 4/25
19/19 - 0s - loss: 0.6305 - acc: 0.7097 - f1_m: 0.6119 - precision_m: 0.8386 - recall_m: 0.4940 - val_loss: 0.6339 - val_acc: 0.7037 - val_f1_m: 0.6318 - val_precision_m: 0.8495 - val_recall_m: 0.5342 - 77ms/epoch - 4ms/step
Epoch 5/25
19/19 - 0s - loss: 0.6040 - acc: 0.7357 - f1_m: 0.6734 - precision_m: 0.8362 - recall_m: 0.5748 - val_loss: 0.6104 - val_acc: 0.7346 - val_f1_m: 0.6750 - val_precision_m: 0.8512 - val_recall_m: 0.5917 - 76ms/epoch - 4ms/step
Epoch 6/25
19/19 - 0s - loss: 0.5739 - acc: 0.7605 - f1_m: 0.7165 - precision_m: 0.8371 - recall_m: 0.6437 - val_loss: 0.5806 - val_acc: 0.7778 - val_f1_m: 0.7432 - val_precision_m: 0.8598 - val_recall_m: 0.6750 - 77ms/epoch - 4ms/step
Epoch 7/25
19/19 - 0s - loss: 0.5397 - acc: 0.7816 - f1_m: 0.7522 - precision_m: 0.8293 - recall_m: 0.7059 - val_loss: 0.5449 - val_acc: 0.7840 - val_f1_m: 0.7619 - val_precision_m: 0.8438 - val_recall_m: 0.7183 - 76ms/epoch - 4ms/step
Epoch 8/25
19/19 - 0s - loss: 0.5016 - acc: 0.7965 - f1_m: 0.7764 - precision_m: 0.8147 - recall_m: 0.7641 - val_loss: 0.5091 - val_acc: 0.7901 - val_f1_m: 0.7816 - val_precision_m: 0.8083 - val_recall_m: 0.7892 - 76ms/epoch - 4ms/step
Epoch 9/25
19/19 - 0s - loss: 0.4673 - acc: 0.8077 - f1_m: 0.7961 - precision_m: 0.8095 - recall_m: 0.8079 - val_loss: 0.4777 - val_acc: 0.7901 - val_f1_m: 0.7859 - val_precision_m: 0.7884 - val_recall_m: 0.8192 - 76ms/epoch - 4ms/step
Epoch 10/25
19/19 - 0s - loss: 0.4413 - acc: 0.8139 - f1_m: 0.8045 - precision_m: 0.8074 - recall_m: 0.8302 - val_loss: 0.4569 - val_acc: 0.7901 - val_f1_m: 0.7859 - val_precision_m: 0.7884 - val_recall_m: 0.8192 - 72ms/epoch - 4ms/step
Epoch 11/25
19/19 - 0s - loss: 0.4229 - acc: 0.8201 - f1_m: 0.8111 - precision_m: 0.8131 - recall_m: 0.8386 - val_loss: 0.4456 - val_acc: 0.8025 - val_f1_m: 0.8000 - val_precision_m: 0.8046 - val_recall_m: 0.8400 - 78ms/epoch - 4ms/step
Epoch 12/25
19/19 - 0s - loss: 0.4097 - acc: 0.8263 - f1_m: 0.8172 - precision_m: 0.8140 - recall_m: 0.8462 - val_loss: 0.4415 - val_acc: 0.8025 - val_f1_m: 0.8000 - val_precision_m: 0.8046 - val_recall_m: 0.8400 - 79ms/epoch - 4ms/step
Epoch 13/25
19/19 - 0s - loss: 0.3971 - acc: 0.8300 - f1_m: 0.8195 - precision_m: 0.8155 - recall_m: 0.8487 - val_loss: 0.4404 - val_acc: 0.8025 - val_f1_m: 0.8000 - val_precision_m: 0.8046 - val_recall_m: 0.8400 - 79ms/epoch - 4ms/step
Epoch 14/25
19/19 - 0s - loss: 0.3886 - acc: 0.8325 - f1_m: 0.8231 - precision_m: 0.8169 - recall_m: 0.8555 - val_loss: 0.4404 - val_acc: 0.7963 - val_f1_m: 0.7952 - val_precision_m: 0.7992 - val_recall_m: 0.8400 - 79ms/epoch - 4ms/step
Epoch 15/25
19/19 - 0s - loss: 0.3800 - acc: 0.8362 - f1_m: 0.8263 - precision_m: 0.8213 - recall_m: 0.8555 - val_loss: 0.4415 - val_acc: 0.7963 - val_f1_m: 0.7952 - val_precision_m: 0.7992 - val_recall_m: 0.8400 - 80ms/epoch - 4ms/step
Epoch 16/25
19/19 - 0s - loss: 0.3733 - acc: 0.8400 - f1_m: 0.8318 - precision_m: 0.8223 - recall_m: 0.8648 - val_loss: 0.4431 - val_acc: 0.7963 - val_f1_m: 0.7948 - val_precision_m: 0.7941 - val_recall_m: 0.8400 - 76ms/epoch - 4ms/step
Epoch 17/25
19/19 - 0s - loss: 0.3664 - acc: 0.8437 - f1_m: 0.8351 - precision_m: 0.8261 - recall_m: 0.8665 - val_loss: 0.4448 - val_acc: 0.7963 - val_f1_m: 0.7948 - val_precision_m: 0.7941 - val_recall_m: 0.8400 - 75ms/epoch - 4ms/step
Epoch 18/25
19/19 - 0s - loss: 0.3604 - acc: 0.8462 - f1_m: 0.8393 - precision_m: 0.8304 - recall_m: 0.8707 - val_loss: 0.4468 - val_acc: 0.7963 - val_f1_m: 0.7948 - val_precision_m: 0.7941 - val_recall_m: 0.8400 - 78ms/epoch - 4ms/step
Epoch 19/25
19/19 - 0s - loss: 0.3545 - acc: 0.8511 - f1_m: 0.8438 - precision_m: 0.8363 - recall_m: 0.8733 - val_loss: 0.4489 - val_acc: 0.8025 - val_f1_m: 0.7999 - val_precision_m: 0.8000 - val_recall_m: 0.8400 - 72ms/epoch - 4ms/step
Epoch 20/25
19/19 - 0s - loss: 0.3490 - acc: 0.8524 - f1_m: 0.8448 - precision_m: 0.8380 - recall_m: 0.8733 - val_loss: 0.4511 - val_acc: 0.8025 - val_f1_m: 0.7999 - val_precision_m: 0.8000 - val_recall_m: 0.8400 - 76ms/epoch - 4ms/step
Epoch 21/25
19/19 - 0s - loss: 0.3441 - acc: 0.8610 - f1_m: 0.8528 - precision_m: 0.8473 - recall_m: 0.8765 - val_loss: 0.4532 - val_acc: 0.8025 - val_f1_m: 0.7999 - val_precision_m: 0.8000 - val_recall_m: 0.8400 - 77ms/epoch - 4ms/step
Epoch 22/25
19/19 - 0s - loss: 0.3386 - acc: 0.8660 - f1_m: 0.8575 - precision_m: 0.8522 - recall_m: 0.8793 - val_loss: 0.4549 - val_acc: 0.7963 - val_f1_m: 0.7931 - val_precision_m: 0.7987 - val_recall_m: 0.8300 - 74ms/epoch - 4ms/step
Epoch 23/25
19/19 - 0s - loss: 0.3338 - acc: 0.8697 - f1_m: 0.8610 - precision_m: 0.8565 - recall_m: 0.8819 - val_loss: 0.4561 - val_acc: 0.7963 - val_f1_m: 0.7931 - val_precision_m: 0.7987 - val_recall_m: 0.8300 - 77ms/epoch - 4ms/step
Epoch 24/25
19/19 - 0s - loss: 0.3285 - acc: 0.8722 - f1_m: 0.8634 - precision_m: 0.8602 - recall_m: 0.8824 - val_loss: 0.4580 - val_acc: 0.7963 - val_f1_m: 0.7931 - val_precision_m: 0.7987 - val_recall_m: 0.8300 - 73ms/epoch - 4ms/step
Epoch 25/25
19/19 - 0s - loss: 0.3237 - acc: 0.8747 - f1_m: 0.8663 - precision_m: 0.8634 - recall_m: 0.8854 - val_loss: 0.4600 - val_acc: 0.7963 - val_f1_m: 0.7910 - val_precision_m: 0.8077 - val_recall_m: 0.8200 - 82ms/epoch - 4ms/step
In [ ]:
from matplotlib import pyplot as plt

%matplotlib inline

plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_4.history['f1_m'])

plt.ylabel('F1 Score %')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_4.history['val_f1_m'])
plt.title('Validation')

plt.subplot(2,2,3)
plt.plot(history_4.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')

plt.subplot(2,2,4)
plt.plot(history_4.history['val_loss'])
plt.xlabel('epochs')
Out[ ]:
Text(0.5, 0, 'epochs')
No description has been provided for this image

Generally, I saw my model convergence all over the map on multiple runs. I changed the number of epochs after a few test runs. Now I'm converging between 10-15 epochs. However I'd like to see my loss getting overall a little lower to be confident I'm getting better performance. I'm not seeing that between the orignal model 2 and this altered version. I do see overtraining occurring in this model as my validation training loss begins to trend back upward after about 12 epochs.

To verify whether my models are really different, I'll peform a statistical analysis of the two.

In [ ]:
#  Vizualize some metrics associated with this model

# Source: Modified from in-class lecture
# Use the sklearn metrics here, if you want to
# from sklearn import metrics as mt

# y_test = tf.concat([y for x, y in ds_test], axis=0)
# y_test = y_test.numpy()

# now lets see how well the model performed
yhat_proba_4 = training_model_4.predict(ds_test) # sigmoid output probabilities
# use squeeze to get rid of any extra dimensions 
yhat_4 = np.round(yhat_proba_4.squeeze()) # round to get binary class

conf_mat_4 = mt.confusion_matrix(y_test, yhat_4)

print(conf_mat_4)
print(mt.classification_report(y_test,yhat_4))

# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 20180309.  VitalBook file.
# Create pandas dataframe
conf_df_4 = pd.DataFrame(conf_mat_4, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())

# Create heatmap
sns.heatmap(conf_df_4, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 1s 2ms/step
[[64 16]
 [17 65]]
              precision    recall  f1-score   support

           0       0.79      0.80      0.80        80
           1       0.80      0.79      0.80        82

    accuracy                           0.80       162
   macro avg       0.80      0.80      0.80       162
weighted avg       0.80      0.80      0.80       162

No description has been provided for this image

Comparison between model 2 and model 4¶

To determine which model truly performed better I need to understand whether the models are really different from one another.

Initially, I split my data using KFold and 6 folds, then used those training and test sets for each model. So the models are all trained on the exact same data. So instead of looking at the f1_scores for each fold, I'm going to compare the f1_scores measured during model fitting.

In another dataset, this could potentially cause me problems as I'm comparing the results of the entire dataset to one another as opposed to each fold. However, because my prediction classes are perfectly balanced, and the data for each model is identical, I believe this will still yield some useful information for comparison.

Note, ideally I'd run a cross_val_score here but I couldn't get it to work with the keras models despite multiple efforts.

In [ ]:
from scipy.stats import t

# Get the histories of val_f1 scores from my two models for comparison
f1_score_model_2 = history_2.history['val_f1_m']
f1_score_model_4 = history_4.history['val_f1_m']

# get error rates for both model's f1 scores
model_2_err = [1 - f1 for f1 in f1_score_model_2]
model_4_err = [1 - f1 for f1 in f1_score_model_4]

d = []
for err in range(len(model_2_err)):
    d.append(model_2_err[err] - model_4_err[err])

dbar = sum(d) / len(d)
stdtot = np.std(d)

epochs = 12
confidence_level = 0.95
degrees_of_freedom = epochs

# Calculate the critical value, t
t = t.ppf((1 + confidence_level) / 2, degrees_of_freedom)

# print(f'The error of the three models is\n', acc1.mean(), '\n', acc2.mean(), '\n', acc3.mean())
print('Range of:', dbar-t*stdtot,dbar+t*stdtot, 'between model 2 and model 4')
Range of: -0.3139904159770698 0.12504712725344203 between model 2 and model 4

An interesting note about my statistical analysis. I ran these models multiple times. There were times in which the models showed there was no statistical difference (range contained 0), and there were times it showed there was (range did not contain 0). This was with no model changes at all. That doesn't fill me with confidence that I've implemented everything correctly. And gives me even less faith in statistics. Here are a few values from sample runs:

  • Range of: -0.33352213888834137 0.20063908685714227 between model 2 and model 4 -- Statistically different
  • Range of: -0.086830071054072 0.0883182957193959 between model 2 and model 4 -- Statistically not different
  • Range of: -0.08561457731365749 0.06702024875123888 between model 2 and model 4 -- Statistically different

Because of these mixed results I'll speak to these results from two angles. If the two models are NOT statistically different, then any given run may result in values that are better or worse than one another, and therefore either model performs roughly the same. However, if we take the opposite case that they ARE statistically different, then the difference could only be minor as the results overalp so frequently.

Performance Comparision vs MLP (Deep Side)¶

For this comparison I'll be using just the deep side of my wide and deep neural network along with my best performing model. I've mentioned the interchangeability of "best model" several times throughout this notebook, so which model I select does not truly make a large difference here. For the sake of consistency, I'm going to use model 2 as I did for my previous section's comparisons.

In [ ]:
# Source: Modified from in-class lecture to match my dataset
from tensorflow.keras.utils import FeatureSpace

feature_space_mlp = FeatureSpace(
    features={
        # Categorical feature encoded as string
        "type_of_breast_surgery": FeatureSpace.string_categorical(num_oov_indices=0),
        "cancer_type_detailed": FeatureSpace.string_categorical(num_oov_indices=0),
        "cellularity": FeatureSpace.string_categorical(num_oov_indices=0),
        "pam50_plus_claudin-low_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        "er_status_measured_by_ihc": FeatureSpace.string_categorical(num_oov_indices=0),
        "er_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "her2_status_measured_by_snp6": FeatureSpace.string_categorical(num_oov_indices=0),
        "her2_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        # "tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        "inferred_menopausal_state": FeatureSpace.string_categorical(num_oov_indices=0),
        "integrative_cluster": FeatureSpace.string_categorical(num_oov_indices=0),
        "primary_tumor_laterality": FeatureSpace.string_categorical(num_oov_indices=0),
        "oncotree_code": FeatureSpace.string_categorical(num_oov_indices=0),
        "pr_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "3-gene_classifier_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        
        # Numerical features to normalize (normalization will be learned)
        # learns the mean, variance, and if to invert
        "chemotherapy": FeatureSpace.float_normalized(),
        "hormone_therapy": FeatureSpace.float_normalized(),
        "radio_therapy": FeatureSpace.float_normalized(),
        "age_at_diagnosis": FeatureSpace.float_normalized(),
        "neoplasm_histologic_grade": FeatureSpace.float_normalized(),
        "lymph_nodes_examined_positive": FeatureSpace.float_normalized(),
        "mutation_count": FeatureSpace.float_normalized(),
        "nottingham_prognostic_index": FeatureSpace.float_normalized(),
        "overall_survival_months": FeatureSpace.float_normalized(),
        "tumor_size": FeatureSpace.float_normalized(),
        "tumor_stage": FeatureSpace.float_normalized(),
    },
    output_mode="concat", 
)


# now that we have specified the preprocessing, let's run it on the data

# create a version of the dataset that can be iterated without labels
train_ds_with_no_labels = ds_train.map(lambda x, _: x)  
feature_space_mlp.adapt(train_ds_with_no_labels) # inititalize the feature map to this data
In [ ]:
# from keras.metrics import Precision, Recall
dict_inputs = feature_space_mlp.get_inputs() # need to use unprocessed features here, to gain access to each output

# we need to create separate lists for each branch
crossed_outputs = []

# for each crossed variable, make an embedding
for col in feature_space_mlp.crossers.keys():
    
    x = setup_embedding_from_crossing(feature_space_mlp, col)
    
    # save these outputs in list to concatenate later
    crossed_outputs.append(x)
    

# now concatenate the outputs and add a fully connected layer
# wide_branch = Concatenate(name='wide_concat')(crossed_outputs)

# reset this input branch
all_deep_branch_outputs = []

# for each numeric variable, just add it in after embedding
for idx,col in enumerate(numeric_headers):
    x = feature_space_mlp.preprocessors[col].output
    x = tf.cast(x,float) # cast an integer as a float here
    all_deep_branch_outputs.append(x)
    
# for each categorical variable
for col in categorical_headers:
    
    # get the output tensor from ebedding layer
    x = setup_embedding_from_categorical(feature_space_mlp, col)
    
    # save these outputs in list to concatenate later
    all_deep_branch_outputs.append(x)


# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34,activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17,activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=8,activation='relu', name='deep3')(deep_branch) # Changed from 10 to 8 neurons
deep_branch = Dense(units=4,activation='relu', name='deep4')(deep_branch) # This is my new layer
deep_branch = Dense(units=1,activation='sigmoid', name='deep5')(deep_branch) # adding this sigmoid layer to make a complete MLP representation from the Deep side

training_model_mlp = keras.Model(inputs=dict_inputs, outputs=deep_branch)

# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_mlp.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=['acc',f1_m,precision_m, recall_m]
)

training_model_mlp.summary()

plot_model(
    training_model_mlp, to_file='model.png', show_shapes=True, show_layer_names=True,
    rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_39"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 type_of_breast_surgery (InputL  [(None, 1)]         0           []                               
 ayer)                                                                                            
                                                                                                  
 cancer_type_detailed (InputLay  [(None, 1)]         0           []                               
 er)                                                                                              
                                                                                                  
 cellularity (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 pam50_plus_claudin-low_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 er_status_measured_by_ihc (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 er_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 her2_status_measured_by_snp6 (  [(None, 1)]         0           []                               
 InputLayer)                                                                                      
                                                                                                  
 her2_status (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 tumor_other_histologic_subtype  [(None, 1)]         0           []                               
  (InputLayer)                                                                                    
                                                                                                  
 inferred_menopausal_state (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 integrative_cluster (InputLaye  [(None, 1)]         0           []                               
 r)                                                                                               
                                                                                                  
 primary_tumor_laterality (Inpu  [(None, 1)]         0           []                               
 tLayer)                                                                                          
                                                                                                  
 oncotree_code (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 pr_status (InputLayer)         [(None, 1)]          0           []                               
                                                                                                  
 3-gene_classifier_subtype (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 age_at_diagnosis (InputLayer)  [(None, 1)]          0           []                               
                                                                                                  
 neoplasm_histologic_grade (Inp  [(None, 1)]         0           []                               
 utLayer)                                                                                         
                                                                                                  
 lymph_nodes_examined_positive   [(None, 1)]         0           []                               
 (InputLayer)                                                                                     
                                                                                                  
 mutation_count (InputLayer)    [(None, 1)]          0           []                               
                                                                                                  
 nottingham_prognostic_index (I  [(None, 1)]         0           []                               
 nputLayer)                                                                                       
                                                                                                  
 overall_survival_months (Input  [(None, 1)]         0           []                               
 Layer)                                                                                           
                                                                                                  
 tumor_size (InputLayer)        [(None, 1)]          0           []                               
                                                                                                  
 tumor_stage (InputLayer)       [(None, 1)]          0           []                               
                                                                                                  
 chemotherapy (InputLayer)      [(None, 1)]          0           []                               
                                                                                                  
 radio_therapy (InputLayer)     [(None, 1)]          0           []                               
                                                                                                  
 hormone_therapy (InputLayer)   [(None, 1)]          0           []                               
                                                                                                  
 string_categorical_454_preproc  (None, 1)           0           ['type_of_breast_surgery[0][0]'] 
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_455_preproc  (None, 1)           0           ['cancer_type_detailed[0][0]']   
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_456_preproc  (None, 1)           0           ['cellularity[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_457_preproc  (None, 1)           0           ['pam50_plus_claudin-low_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_458_preproc  (None, 1)           0           ['er_status_measured_by_ihc[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_459_preproc  (None, 1)           0           ['er_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_460_preproc  (None, 1)           0           ['her2_status_measured_by_snp6[0]
 essor (StringLookup)                                            [0]']                            
                                                                                                  
 string_categorical_461_preproc  (None, 1)           0           ['her2_status[0][0]']            
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_462_preproc  (None, 1)           0           ['tumor_other_histologic_subtype[
 essor (StringLookup)                                            0][0]']                          
                                                                                                  
 string_categorical_463_preproc  (None, 1)           0           ['inferred_menopausal_state[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 string_categorical_464_preproc  (None, 1)           0           ['integrative_cluster[0][0]']    
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_465_preproc  (None, 1)           0           ['primary_tumor_laterality[0][0]'
 essor (StringLookup)                                            ]                                
                                                                                                  
 string_categorical_466_preproc  (None, 1)           0           ['oncotree_code[0][0]']          
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_467_preproc  (None, 1)           0           ['pr_status[0][0]']              
 essor (StringLookup)                                                                             
                                                                                                  
 string_categorical_468_preproc  (None, 1)           0           ['3-gene_classifier_subtype[0][0]
 essor (StringLookup)                                            ']                               
                                                                                                  
 float_normalized_328_preproces  (None, 1)           3           ['age_at_diagnosis[0][0]']       
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_329_preproces  (None, 1)           3           ['neoplasm_histologic_grade[0][0]
 sor (Normalization)                                             ']                               
                                                                                                  
 float_normalized_330_preproces  (None, 1)           3           ['lymph_nodes_examined_positive[0
 sor (Normalization)                                             ][0]']                           
                                                                                                  
 float_normalized_331_preproces  (None, 1)           3           ['mutation_count[0][0]']         
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_332_preproces  (None, 1)           3           ['nottingham_prognostic_index[0][
 sor (Normalization)                                             0]']                             
                                                                                                  
 float_normalized_333_preproces  (None, 1)           3           ['overall_survival_months[0][0]']
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_334_preproces  (None, 1)           3           ['tumor_size[0][0]']             
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_335_preproces  (None, 1)           3           ['tumor_stage[0][0]']            
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_325_preproces  (None, 1)           3           ['chemotherapy[0][0]']           
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_327_preproces  (None, 1)           3           ['radio_therapy[0][0]']          
 sor (Normalization)                                                                              
                                                                                                  
 float_normalized_326_preproces  (None, 1)           3           ['hormone_therapy[0][0]']        
 sor (Normalization)                                                                              
                                                                                                  
 type_of_breast_surgery_embed (  (None, 1, 1)        2           ['string_categorical_454_preproce
 Embedding)                                                      ssor[0][0]']                     
                                                                                                  
 cancer_type_detailed_embed (Em  (None, 1, 2)        10          ['string_categorical_455_preproce
 bedding)                                                        ssor[0][0]']                     
                                                                                                  
 cellularity_embed (Embedding)  (None, 1, 1)         3           ['string_categorical_456_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 pam50_plus_claudin-low_subtype  (None, 1, 2)        14          ['string_categorical_457_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 er_status_measured_by_ihc_embe  (None, 1, 1)        2           ['string_categorical_458_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 er_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_459_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 her2_status_measured_by_snp6_e  (None, 1, 2)        8           ['string_categorical_460_preproce
 mbed (Embedding)                                                ssor[0][0]']                     
                                                                                                  
 her2_status_embed (Embedding)  (None, 1, 1)         2           ['string_categorical_461_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 tumor_other_histologic_subtype  (None, 1, 2)        14          ['string_categorical_462_preproce
 _embed (Embedding)                                              ssor[0][0]']                     
                                                                                                  
 inferred_menopausal_state_embe  (None, 1, 1)        2           ['string_categorical_463_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 integrative_cluster_embed (Emb  (None, 1, 3)        33          ['string_categorical_464_preproce
 edding)                                                         ssor[0][0]']                     
                                                                                                  
 primary_tumor_laterality_embed  (None, 1, 1)        2           ['string_categorical_465_preproce
  (Embedding)                                                    ssor[0][0]']                     
                                                                                                  
 oncotree_code_embed (Embedding  (None, 1, 2)        10          ['string_categorical_466_preproce
 )                                                               ssor[0][0]']                     
                                                                                                  
 pr_status_embed (Embedding)    (None, 1, 1)         2           ['string_categorical_467_preproce
                                                                 ssor[0][0]']                     
                                                                                                  
 3-gene_classifier_subtype_embe  (None, 1, 2)        8           ['string_categorical_468_preproce
 d (Embedding)                                                   ssor[0][0]']                     
                                                                                                  
 tf.cast_429 (TFOpLambda)       (None, 1)            0           ['float_normalized_328_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_430 (TFOpLambda)       (None, 1)            0           ['float_normalized_329_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_431 (TFOpLambda)       (None, 1)            0           ['float_normalized_330_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_432 (TFOpLambda)       (None, 1)            0           ['float_normalized_331_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_433 (TFOpLambda)       (None, 1)            0           ['float_normalized_332_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_434 (TFOpLambda)       (None, 1)            0           ['float_normalized_333_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_435 (TFOpLambda)       (None, 1)            0           ['float_normalized_334_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_436 (TFOpLambda)       (None, 1)            0           ['float_normalized_335_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_437 (TFOpLambda)       (None, 1)            0           ['float_normalized_325_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_438 (TFOpLambda)       (None, 1)            0           ['float_normalized_327_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 tf.cast_439 (TFOpLambda)       (None, 1)            0           ['float_normalized_326_preprocess
                                                                 or[0][0]']                       
                                                                                                  
 flatten_689 (Flatten)          (None, 1)            0           ['type_of_breast_surgery_embed[0]
                                                                 [0]']                            
                                                                                                  
 flatten_690 (Flatten)          (None, 2)            0           ['cancer_type_detailed_embed[0][0
                                                                 ]']                              
                                                                                                  
 flatten_691 (Flatten)          (None, 1)            0           ['cellularity_embed[0][0]']      
                                                                                                  
 flatten_692 (Flatten)          (None, 2)            0           ['pam50_plus_claudin-low_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_693 (Flatten)          (None, 1)            0           ['er_status_measured_by_ihc_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_694 (Flatten)          (None, 1)            0           ['er_status_embed[0][0]']        
                                                                                                  
 flatten_695 (Flatten)          (None, 2)            0           ['her2_status_measured_by_snp6_em
                                                                 bed[0][0]']                      
                                                                                                  
 flatten_696 (Flatten)          (None, 1)            0           ['her2_status_embed[0][0]']      
                                                                                                  
 flatten_697 (Flatten)          (None, 2)            0           ['tumor_other_histologic_subtype_
                                                                 embed[0][0]']                    
                                                                                                  
 flatten_698 (Flatten)          (None, 1)            0           ['inferred_menopausal_state_embed
                                                                 [0][0]']                         
                                                                                                  
 flatten_699 (Flatten)          (None, 3)            0           ['integrative_cluster_embed[0][0]
                                                                 ']                               
                                                                                                  
 flatten_700 (Flatten)          (None, 1)            0           ['primary_tumor_laterality_embed[
                                                                 0][0]']                          
                                                                                                  
 flatten_701 (Flatten)          (None, 2)            0           ['oncotree_code_embed[0][0]']    
                                                                                                  
 flatten_702 (Flatten)          (None, 1)            0           ['pr_status_embed[0][0]']        
                                                                                                  
 flatten_703 (Flatten)          (None, 2)            0           ['3-gene_classifier_subtype_embed
                                                                 [0][0]']                         
                                                                                                  
 embed_concat (Concatenate)     (None, 34)           0           ['tf.cast_429[0][0]',            
                                                                  'tf.cast_430[0][0]',            
                                                                  'tf.cast_431[0][0]',            
                                                                  'tf.cast_432[0][0]',            
                                                                  'tf.cast_433[0][0]',            
                                                                  'tf.cast_434[0][0]',            
                                                                  'tf.cast_435[0][0]',            
                                                                  'tf.cast_436[0][0]',            
                                                                  'tf.cast_437[0][0]',            
                                                                  'tf.cast_438[0][0]',            
                                                                  'tf.cast_439[0][0]',            
                                                                  'flatten_689[0][0]',            
                                                                  'flatten_690[0][0]',            
                                                                  'flatten_691[0][0]',            
                                                                  'flatten_692[0][0]',            
                                                                  'flatten_693[0][0]',            
                                                                  'flatten_694[0][0]',            
                                                                  'flatten_695[0][0]',            
                                                                  'flatten_696[0][0]',            
                                                                  'flatten_697[0][0]',            
                                                                  'flatten_698[0][0]',            
                                                                  'flatten_699[0][0]',            
                                                                  'flatten_700[0][0]',            
                                                                  'flatten_701[0][0]',            
                                                                  'flatten_702[0][0]',            
                                                                  'flatten_703[0][0]']            
                                                                                                  
 deep1 (Dense)                  (None, 34)           1190        ['embed_concat[0][0]']           
                                                                                                  
 deep2 (Dense)                  (None, 17)           595         ['deep1[0][0]']                  
                                                                                                  
 deep3 (Dense)                  (None, 8)            144         ['deep2[0][0]']                  
                                                                                                  
 deep4 (Dense)                  (None, 4)            36          ['deep3[0][0]']                  
                                                                                                  
 deep5 (Dense)                  (None, 1)            5           ['deep4[0][0]']                  
                                                                                                  
==================================================================================================
Total params: 2,117
Trainable params: 2,084
Non-trainable params: 33
__________________________________________________________________________________________________
Out[ ]:
No description has been provided for this image
In [ ]:
# train using the already processed features
history_mlp = training_model_mlp.fit(
    ds_train, epochs=35, validation_data=ds_test, verbose=2
) # changed number of epochs
Epoch 1/35
19/19 - 6s - loss: 0.6599 - acc: 0.5645 - f1_m: 0.2866 - precision_m: 0.6511 - recall_m: 0.1948 - val_loss: 0.6304 - val_acc: 0.6543 - val_f1_m: 0.5149 - val_precision_m: 0.8018 - val_recall_m: 0.4025 - 6s/epoch - 311ms/step
Epoch 2/35
19/19 - 0s - loss: 0.6023 - acc: 0.7370 - f1_m: 0.6602 - precision_m: 0.8006 - recall_m: 0.5892 - val_loss: 0.5950 - val_acc: 0.7778 - val_f1_m: 0.7736 - val_precision_m: 0.7866 - val_recall_m: 0.8067 - 71ms/epoch - 4ms/step
Epoch 3/35
19/19 - 0s - loss: 0.5685 - acc: 0.7742 - f1_m: 0.7608 - precision_m: 0.7758 - recall_m: 0.7939 - val_loss: 0.5600 - val_acc: 0.8210 - val_f1_m: 0.8179 - val_precision_m: 0.8201 - val_recall_m: 0.8700 - 70ms/epoch - 4ms/step
Epoch 4/35
19/19 - 0s - loss: 0.5355 - acc: 0.7816 - f1_m: 0.7721 - precision_m: 0.7862 - recall_m: 0.8088 - val_loss: 0.5232 - val_acc: 0.8086 - val_f1_m: 0.8076 - val_precision_m: 0.8072 - val_recall_m: 0.8600 - 71ms/epoch - 4ms/step
Epoch 5/35
19/19 - 0s - loss: 0.5016 - acc: 0.8040 - f1_m: 0.7949 - precision_m: 0.8070 - recall_m: 0.8272 - val_loss: 0.4926 - val_acc: 0.7963 - val_f1_m: 0.7979 - val_precision_m: 0.7853 - val_recall_m: 0.8600 - 70ms/epoch - 4ms/step
Epoch 6/35
19/19 - 0s - loss: 0.4780 - acc: 0.8065 - f1_m: 0.7978 - precision_m: 0.8048 - recall_m: 0.8335 - val_loss: 0.4691 - val_acc: 0.8025 - val_f1_m: 0.8030 - val_precision_m: 0.7958 - val_recall_m: 0.8600 - 78ms/epoch - 4ms/step
Epoch 7/35
19/19 - 0s - loss: 0.4594 - acc: 0.8164 - f1_m: 0.8076 - precision_m: 0.8070 - recall_m: 0.8477 - val_loss: 0.4537 - val_acc: 0.7963 - val_f1_m: 0.7981 - val_precision_m: 0.7961 - val_recall_m: 0.8525 - 80ms/epoch - 4ms/step
Epoch 8/35
19/19 - 0s - loss: 0.4458 - acc: 0.8201 - f1_m: 0.8107 - precision_m: 0.8083 - recall_m: 0.8510 - val_loss: 0.4444 - val_acc: 0.7963 - val_f1_m: 0.7981 - val_precision_m: 0.7961 - val_recall_m: 0.8525 - 92ms/epoch - 5ms/step
Epoch 9/35
19/19 - 0s - loss: 0.4366 - acc: 0.8213 - f1_m: 0.8124 - precision_m: 0.8086 - recall_m: 0.8538 - val_loss: 0.4384 - val_acc: 0.7963 - val_f1_m: 0.7981 - val_precision_m: 0.7961 - val_recall_m: 0.8525 - 64ms/epoch - 3ms/step
Epoch 10/35
19/19 - 0s - loss: 0.4282 - acc: 0.8226 - f1_m: 0.8136 - precision_m: 0.8107 - recall_m: 0.8542 - val_loss: 0.4344 - val_acc: 0.7901 - val_f1_m: 0.7884 - val_precision_m: 0.7978 - val_recall_m: 0.8300 - 75ms/epoch - 4ms/step
Epoch 11/35
19/19 - 0s - loss: 0.4210 - acc: 0.8300 - f1_m: 0.8192 - precision_m: 0.8189 - recall_m: 0.8542 - val_loss: 0.4320 - val_acc: 0.7840 - val_f1_m: 0.7824 - val_precision_m: 0.7973 - val_recall_m: 0.8200 - 73ms/epoch - 4ms/step
Epoch 12/35
19/19 - 0s - loss: 0.4135 - acc: 0.8313 - f1_m: 0.8193 - precision_m: 0.8207 - recall_m: 0.8512 - val_loss: 0.4312 - val_acc: 0.7901 - val_f1_m: 0.7897 - val_precision_m: 0.7993 - val_recall_m: 0.8325 - 75ms/epoch - 4ms/step
Epoch 13/35
19/19 - 0s - loss: 0.4080 - acc: 0.8325 - f1_m: 0.8208 - precision_m: 0.8202 - recall_m: 0.8543 - val_loss: 0.4302 - val_acc: 0.7901 - val_f1_m: 0.7897 - val_precision_m: 0.7993 - val_recall_m: 0.8325 - 73ms/epoch - 4ms/step
Epoch 14/35
19/19 - 0s - loss: 0.4024 - acc: 0.8375 - f1_m: 0.8259 - precision_m: 0.8238 - recall_m: 0.8592 - val_loss: 0.4299 - val_acc: 0.7901 - val_f1_m: 0.7897 - val_precision_m: 0.7993 - val_recall_m: 0.8325 - 71ms/epoch - 4ms/step
Epoch 15/35
19/19 - 0s - loss: 0.3973 - acc: 0.8412 - f1_m: 0.8295 - precision_m: 0.8260 - recall_m: 0.8635 - val_loss: 0.4300 - val_acc: 0.7901 - val_f1_m: 0.7897 - val_precision_m: 0.7993 - val_recall_m: 0.8325 - 71ms/epoch - 4ms/step
Epoch 16/35
19/19 - 0s - loss: 0.3928 - acc: 0.8437 - f1_m: 0.8324 - precision_m: 0.8288 - recall_m: 0.8666 - val_loss: 0.4303 - val_acc: 0.7963 - val_f1_m: 0.7945 - val_precision_m: 0.8047 - val_recall_m: 0.8325 - 70ms/epoch - 4ms/step
Epoch 17/35
19/19 - 0s - loss: 0.3873 - acc: 0.8462 - f1_m: 0.8351 - precision_m: 0.8303 - recall_m: 0.8672 - val_loss: 0.4300 - val_acc: 0.8025 - val_f1_m: 0.7995 - val_precision_m: 0.8106 - val_recall_m: 0.8325 - 71ms/epoch - 4ms/step
Epoch 18/35
19/19 - 0s - loss: 0.3825 - acc: 0.8499 - f1_m: 0.8391 - precision_m: 0.8347 - recall_m: 0.8703 - val_loss: 0.4296 - val_acc: 0.8086 - val_f1_m: 0.8056 - val_precision_m: 0.8112 - val_recall_m: 0.8425 - 72ms/epoch - 4ms/step
Epoch 19/35
19/19 - 0s - loss: 0.3777 - acc: 0.8548 - f1_m: 0.8433 - precision_m: 0.8394 - recall_m: 0.8722 - val_loss: 0.4293 - val_acc: 0.8086 - val_f1_m: 0.8056 - val_precision_m: 0.8112 - val_recall_m: 0.8425 - 69ms/epoch - 4ms/step
Epoch 20/35
19/19 - 0s - loss: 0.3731 - acc: 0.8586 - f1_m: 0.8463 - precision_m: 0.8424 - recall_m: 0.8735 - val_loss: 0.4293 - val_acc: 0.8086 - val_f1_m: 0.8056 - val_precision_m: 0.8112 - val_recall_m: 0.8425 - 69ms/epoch - 4ms/step
Epoch 21/35
19/19 - 0s - loss: 0.3684 - acc: 0.8598 - f1_m: 0.8482 - precision_m: 0.8451 - recall_m: 0.8748 - val_loss: 0.4287 - val_acc: 0.8086 - val_f1_m: 0.8056 - val_precision_m: 0.8112 - val_recall_m: 0.8425 - 70ms/epoch - 4ms/step
Epoch 22/35
19/19 - 0s - loss: 0.3648 - acc: 0.8623 - f1_m: 0.8505 - precision_m: 0.8463 - recall_m: 0.8772 - val_loss: 0.4294 - val_acc: 0.8148 - val_f1_m: 0.8113 - val_precision_m: 0.8117 - val_recall_m: 0.8525 - 70ms/epoch - 4ms/step
Epoch 23/35
19/19 - 0s - loss: 0.3593 - acc: 0.8623 - f1_m: 0.8512 - precision_m: 0.8448 - recall_m: 0.8800 - val_loss: 0.4301 - val_acc: 0.8148 - val_f1_m: 0.8113 - val_precision_m: 0.8117 - val_recall_m: 0.8525 - 67ms/epoch - 4ms/step
Epoch 24/35
19/19 - 0s - loss: 0.3556 - acc: 0.8648 - f1_m: 0.8534 - precision_m: 0.8464 - recall_m: 0.8818 - val_loss: 0.4301 - val_acc: 0.8148 - val_f1_m: 0.8113 - val_precision_m: 0.8117 - val_recall_m: 0.8525 - 69ms/epoch - 4ms/step
Epoch 25/35
19/19 - 0s - loss: 0.3511 - acc: 0.8648 - f1_m: 0.8536 - precision_m: 0.8464 - recall_m: 0.8818 - val_loss: 0.4306 - val_acc: 0.8025 - val_f1_m: 0.8011 - val_precision_m: 0.7956 - val_recall_m: 0.8525 - 74ms/epoch - 4ms/step
Epoch 26/35
19/19 - 0s - loss: 0.3479 - acc: 0.8672 - f1_m: 0.8563 - precision_m: 0.8473 - recall_m: 0.8865 - val_loss: 0.4293 - val_acc: 0.8025 - val_f1_m: 0.8011 - val_precision_m: 0.7956 - val_recall_m: 0.8525 - 73ms/epoch - 4ms/step
Epoch 27/35
19/19 - 0s - loss: 0.3427 - acc: 0.8672 - f1_m: 0.8561 - precision_m: 0.8468 - recall_m: 0.8865 - val_loss: 0.4292 - val_acc: 0.7963 - val_f1_m: 0.7962 - val_precision_m: 0.7864 - val_recall_m: 0.8525 - 73ms/epoch - 4ms/step
Epoch 28/35
19/19 - 0s - loss: 0.3387 - acc: 0.8734 - f1_m: 0.8635 - precision_m: 0.8499 - recall_m: 0.8983 - val_loss: 0.4298 - val_acc: 0.7901 - val_f1_m: 0.7914 - val_precision_m: 0.7810 - val_recall_m: 0.8525 - 70ms/epoch - 4ms/step
Epoch 29/35
19/19 - 0s - loss: 0.3345 - acc: 0.8734 - f1_m: 0.8635 - precision_m: 0.8499 - recall_m: 0.8983 - val_loss: 0.4298 - val_acc: 0.7901 - val_f1_m: 0.7914 - val_precision_m: 0.7810 - val_recall_m: 0.8525 - 71ms/epoch - 4ms/step
Epoch 30/35
19/19 - 0s - loss: 0.3312 - acc: 0.8734 - f1_m: 0.8635 - precision_m: 0.8499 - recall_m: 0.8983 - val_loss: 0.4303 - val_acc: 0.7901 - val_f1_m: 0.7914 - val_precision_m: 0.7810 - val_recall_m: 0.8525 - 71ms/epoch - 4ms/step
Epoch 31/35
19/19 - 0s - loss: 0.3267 - acc: 0.8784 - f1_m: 0.8682 - precision_m: 0.8585 - recall_m: 0.8983 - val_loss: 0.4310 - val_acc: 0.7901 - val_f1_m: 0.7914 - val_precision_m: 0.7810 - val_recall_m: 0.8525 - 78ms/epoch - 4ms/step
Epoch 32/35
19/19 - 0s - loss: 0.3223 - acc: 0.8797 - f1_m: 0.8707 - precision_m: 0.8572 - recall_m: 0.9043 - val_loss: 0.4319 - val_acc: 0.7963 - val_f1_m: 0.7982 - val_precision_m: 0.7823 - val_recall_m: 0.8625 - 71ms/epoch - 4ms/step
Epoch 33/35
19/19 - 0s - loss: 0.3179 - acc: 0.8834 - f1_m: 0.8746 - precision_m: 0.8619 - recall_m: 0.9074 - val_loss: 0.4326 - val_acc: 0.7963 - val_f1_m: 0.7982 - val_precision_m: 0.7823 - val_recall_m: 0.8625 - 76ms/epoch - 4ms/step
Epoch 34/35
19/19 - 0s - loss: 0.3142 - acc: 0.8859 - f1_m: 0.8769 - precision_m: 0.8652 - recall_m: 0.9074 - val_loss: 0.4327 - val_acc: 0.8025 - val_f1_m: 0.8030 - val_precision_m: 0.7877 - val_recall_m: 0.8625 - 76ms/epoch - 4ms/step
Epoch 35/35
19/19 - 0s - loss: 0.3087 - acc: 0.8883 - f1_m: 0.8786 - precision_m: 0.8669 - recall_m: 0.9086 - val_loss: 0.4340 - val_acc: 0.8025 - val_f1_m: 0.8032 - val_precision_m: 0.7915 - val_recall_m: 0.8625 - 78ms/epoch - 4ms/step
In [ ]:
from matplotlib import pyplot as plt

%matplotlib inline

plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_mlp.history['f1_m'])

plt.ylabel('F1 Score %')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_mlp.history['val_f1_m'])
plt.title('Validation')

plt.subplot(2,2,3)
plt.plot(history_mlp.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')

plt.subplot(2,2,4)
plt.plot(history_mlp.history['val_loss'])
plt.xlabel('epochs')
Out[ ]:
Text(0.5, 0, 'epochs')
No description has been provided for this image
In [ ]:
#  Vizualize some metrics associated with this model
# Source: Modified from in-class lecture

# now lets see how well the model performed
yhat_proba_mlp = training_model_mlp.predict(ds_test) # sigmoid output probabilities
# use squeeze to get rid of any extra dimensions 
yhat_4 = np.round(yhat_proba_4.squeeze()) # round to get binary class

conf_mat_4 = mt.confusion_matrix(y_test, yhat_4)

print(conf_mat_4)
print(mt.classification_report(y_test,yhat_4))

# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 20180309.  VitalBook file.
# Create pandas dataframe
conf_df_4 = pd.DataFrame(conf_mat_4, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())

# Create heatmap
sns.heatmap(conf_df_4, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 0s 2ms/step
[[64 16]
 [17 65]]
              precision    recall  f1-score   support

           0       0.79      0.80      0.80        80
           1       0.80      0.79      0.80        82

    accuracy                           0.80       162
   macro avg       0.80      0.80      0.80       162
weighted avg       0.80      0.80      0.80       162

No description has been provided for this image
In [ ]:
from scipy.stats import t

# Get the histories of val_f1 scores from my two models for comparison
f1_score_model_2 = history_2.history['val_f1_m']
f1_score_model_4 = history_mlp.history['val_f1_m']

# get error rates for both model's f1 scores
model_2_err = [1 - f1 for f1 in f1_score_model_2]
model_4_err = [1 - f1 for f1 in f1_score_model_4]

d = []
for err in range(len(model_2_err)):
    d.append(model_2_err[err] - model_4_err[err])

dbar = sum(d) / len(d)
stdtot = np.std(d)

epochs = 12
confidence_level = 0.95
degrees_of_freedom = epochs

# Calculate the critical value, t
t = t.ppf((1 + confidence_level) / 2, degrees_of_freedom)

# print(f'The error of the three models is\n', acc1.mean(), '\n', acc2.mean(), '\n', acc3.mean())
print('Range of:', dbar-t*stdtot,dbar+t*stdtot, 'between model 2 and the mlp model')
Range of: -0.08233255758287983 0.07009268019996559 between model 2 and the mlp model

Everything here compares well with previous results. No significant change to the confusion matrix, and no significant change to my statistical analysis. This result indicates the models are not statistically different from one another.

Now let's check out the ROC curve for these two models.

In [ ]:
# Load libraries
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.model_selection import train_test_split

# # Create feature matrix and target vector
# features, target = make_classification(n_samples=10000,
#                                        n_features=10,
#                                        n_classes=2,
#                                        n_informative=3,
#                                        random_state=3)

# # Split into training and test sets
# features_train, features_test, target_train, target_test = train_test_split(
#     features, target, test_size=0.1, random_state=1)

# # Create classifier
# logit = LogisticRegression()

# # Train model
# logit.fit(features_train, target_train)

# # Get predicted probabilities
# target_probabilities = logit.predict_proba(features_test)[:,1]

# Create true and false positive rates
false_positive_rate_mlp, true_positive_rate_mlp, threshold = roc_curve(y_test,
                                                               yhat_proba_mlp)

false_positive_rate_2, true_positive_rate_2, threshold = roc_curve(y_test,
                                                               yhat_proba_2)


# Plot ROC curve
plt.title("Receiver Operating Characteristic")
plt.plot(false_positive_rate_mlp, true_positive_rate_mlp, label='MLP curve')
plt.plot(false_positive_rate_2, true_positive_rate_2, label='Model 2')
plt.plot([0, 1], ls="--")
plt.plot([0, 0], [1, 0] , c=".7"), plt.plot([1, 1] , c=".7")
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")
plt.legend()
plt.show()
No description has been provided for this image

The resulting ROC curve is increadibly close which aligns with the analysis to show that there is little to no statistical difference in these models. If I had to pick, my MLP only model slightly outperforms as the area under the curve is larger, though it appears very slight.

Takeaways¶

This was an interesting dataset to analyze and try to understnad how well we could predict the result. At roughly an 80 to 84% F1 score I'm not confident I would deploy this model for use. While it could be a guiding point in discussions on patient outcome, it contains enough error that I would be hesitant to rely on it as a predictor of outcomes. A further exploration of some of the genomic features may be warranted to see if they would lend additional insight into this analysis.